001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.csv;
018
019import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
022import com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance;
023import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
026import com.oracle.labs.mlrg.olcut.provenance.primitives.URLProvenance;
027import com.oracle.labs.mlrg.olcut.util.Pair;
028import org.tribuo.Example;
029import org.tribuo.MutableDataset;
030import org.tribuo.Output;
031import org.tribuo.OutputFactory;
032import org.tribuo.datasource.ListDataSource;
033import org.tribuo.impl.ArrayExample;
034import org.tribuo.provenance.DataSourceProvenance;
035import org.tribuo.provenance.OutputFactoryProvenance;
036
037import java.io.BufferedReader;
038import java.io.FileInputStream;
039import java.io.FileNotFoundException;
040import java.io.IOException;
041import java.io.InputStreamReader;
042import java.io.Reader;
043import java.net.URISyntaxException;
044import java.net.URL;
045import java.nio.charset.StandardCharsets;
046import java.nio.file.Files;
047import java.nio.file.Path;
048import java.time.OffsetDateTime;
049import java.util.ArrayList;
050import java.util.Arrays;
051import java.util.Collections;
052import java.util.HashMap;
053import java.util.HashSet;
054import java.util.Iterator;
055import java.util.List;
056import java.util.Map;
057import java.util.Objects;
058import java.util.Optional;
059import java.util.Set;
060import java.util.logging.Logger;
061
062/**
063 * Load a DataSource/Dataset from a CSV file.
064 * <p>
065 * The delimiter and quote characters are user controlled, so this class can parse TSVs,
066 * CSVs, semi-colon separated data and other types of single character delimiter separated data.
067 * <p>
068 * This class is a simple loader *only* for numerical CSV files with a String response field.
069 * If you need more complex processing, the response field isn't present, or you don't wish to
070 * use all of the columns as features then you should use {@link CSVDataSource} and build a
071 * {@link org.tribuo.data.columnar.RowProcessor} to cope with your specific input format.
072 * <p>
073 * CSVLoader is thread safe and immutable.
074 * @param <T> The type of the output generated.
075 */
076public class CSVLoader<T extends Output<T>> {
077
078    private static final Logger logger = Logger.getLogger(CSVLoader.class.getName());
079
080    private final char separator;
081    private final char quote;
082    private final OutputFactory<T> outputFactory;
083
084    /**
085     * Creates a CSVLoader using the supplied separator, quote and output factory.
086     * @param separator The separator character.
087     * @param quote The quote character.
088     * @param outputFactory The output factory.
089     */
090    public CSVLoader(char separator, char quote, OutputFactory<T> outputFactory) {
091        this.separator = separator;
092        this.quote = quote;
093        this.outputFactory = outputFactory;
094    }
095
096    /**
097     * Creates a CSVLoader using the supplied separator and output factory.
098     * Sets the quote to {@link CSVIterator#QUOTE}.
099     * @param separator The separator character.
100     * @param outputFactory The output factory.
101     */
102    public CSVLoader(char separator, OutputFactory<T> outputFactory) {
103        this(separator, CSVIterator.QUOTE, outputFactory);
104    }
105
106    /**
107     * Creates a CSVLoader using the supplied output factory.
108     * Sets the separator to {@link CSVIterator#SEPARATOR} and the quote to {@link CSVIterator#QUOTE}.
109     * @param outputFactory The output factory.
110     */
111    public CSVLoader(OutputFactory<T> outputFactory) {
112        this(CSVIterator.SEPARATOR, CSVIterator.QUOTE, outputFactory);
113    }
114
115    /**
116     * Loads a DataSource from the specified csv file then wraps it in a dataset.
117     *
118     * @param csvPath      The path to load.
119     * @param responseName The name of the response variable.
120     * @return A dataset containing the csv data.
121     * @throws IOException If the read failed.
122     */
123    public MutableDataset<T> load(Path csvPath, String responseName) throws IOException {
124        return new MutableDataset<>(loadDataSource(csvPath, responseName));
125    }
126
127    /**
128     * Loads a DataSource from the specified csv file then wraps it in a dataset.
129     *
130     * @param csvPath      The path to load.
131     * @param responseName The name of the response variable.
132     * @param header       The header of the CSV if it's not present in the file.
133     * @return A dataset containing the csv data.
134     * @throws IOException If the read failed.
135     */
136    public MutableDataset<T> load(Path csvPath, String responseName, String[] header) throws IOException {
137        return new MutableDataset<>(loadDataSource(csvPath, responseName, header));
138    }
139
140    /**
141     * Loads a DataSource from the specified csv file then wraps it in a dataset.
142     *
143     * @param csvPath       The path to load.
144     * @param responseNames The names of the response variables.
145     * @return A dataset containing the csv data.
146     * @throws IOException If the read failed.
147     */
148    public MutableDataset<T> load(Path csvPath, Set<String> responseNames) throws IOException {
149        return new MutableDataset<>(loadDataSource(csvPath, responseNames));
150    }
151
152    /**
153     * Loads a DataSource from the specified csv file then wraps it in a dataset.
154     *
155     * @param csvPath       The path to load.
156     * @param responseNames The names of the response variables.
157     * @param header        The header of the CSV if it's not present in the file.
158     * @return A dataset containing the csv data.
159     * @throws IOException If the read failed.
160     */
161    public MutableDataset<T> load(Path csvPath, Set<String> responseNames, String[] header) throws IOException {
162        return new MutableDataset<>(loadDataSource(csvPath, responseNames, header));
163    }
164
165    /**
166     * Loads a DataSource from the specified csv path.
167     *
168     * @param csvPath      The csv to load from.
169     * @param responseName The name of the response variable.
170     * @return A datasource containing the csv data.
171     * @throws IOException If the disk read failed.
172     */
173    public ListDataSource<T> loadDataSource(Path csvPath, String responseName) throws IOException {
174        return loadDataSource(csvPath, Collections.singleton(responseName));
175    }
176
177    /**
178     * Loads a DataSource from the specified csv path.
179     *
180     * @param csvPath      The csv to load from.
181     * @param responseName The name of the response variable.
182     * @return A datasource containing the csv data.
183     * @throws IOException If the disk read failed.
184     */
185    public ListDataSource<T> loadDataSource(URL csvPath, String responseName) throws IOException {
186        return loadDataSource(csvPath, Collections.singleton(responseName));
187    }
188
189    /**
190     * Loads a DataSource from the specified csv path.
191     *
192     * @param csvPath      The csv to load from.
193     * @param responseName The name of the response variable.
194     * @param header       The header of the CSV if it's not present in the file.
195     * @return A datasource containing the csv data.
196     * @throws IOException If the disk read failed.
197     */
198    public ListDataSource<T> loadDataSource(Path csvPath, String responseName, String[] header) throws IOException {
199        return loadDataSource(csvPath, Collections.singleton(responseName), header);
200    }
201
202    /**
203     * Loads a DataSource from the specified csv path.
204     *
205     * @param csvPath      The csv to load from.
206     * @param responseName The name of the response variable.
207     * @param header       The header of the CSV if it's not present in the file.
208     * @return A datasource containing the csv data.
209     * @throws IOException If the disk read failed.
210     */
211    public ListDataSource<T> loadDataSource(URL csvPath, String responseName, String[] header) throws IOException {
212        return loadDataSource(csvPath, Collections.singleton(responseName), header);
213    }
214
215    /**
216     * Loads a DataSource from the specified csv path.
217     *
218     * @param csvPath       The csv to load from.
219     * @param responseNames The names of the response variables.
220     * @return A datasource containing the csv data.
221     * @throws IOException If the disk read failed.
222     */
223    public ListDataSource<T> loadDataSource(Path csvPath, Set<String> responseNames) throws IOException {
224        return loadDataSource(csvPath, responseNames, null);
225    }
226
227    /**
228     * Loads a DataSource from the specified csv path.
229     *
230     * @param csvPath       The csv to load from.
231     * @param responseNames The names of the response variables.
232     * @return A datasource containing the csv data.
233     * @throws IOException If the disk read failed.
234     */
235    public ListDataSource<T> loadDataSource(URL csvPath, Set<String> responseNames) throws IOException {
236        return loadDataSource(csvPath, responseNames, null);
237    }
238
239    /**
240     * Loads a DataSource from the specified csv path.
241     *
242     * @param csvPath       The csv to load from.
243     * @param responseNames The names of the response variables.
244     * @param header        The header of the CSV if it's not present in the file.
245     * @return A datasource containing the csv data.
246     * @throws IOException If the disk read failed.
247     */
248    public ListDataSource<T> loadDataSource(Path csvPath, Set<String> responseNames, String[] header) throws IOException {
249        return loadDataSource(csvPath.toUri().toURL(),responseNames,header);
250    }
251
252    /**
253     * Loads a DataSource from the specified csv path.
254     *
255     * @param csvPath       The csv to load from.
256     * @param responseNames The names of the response variables.
257     * @param header        The header of the CSV if it's not present in the file.
258     * @return A datasource containing the csv data.
259     * @throws IOException If the disk read failed.
260     */
261    public ListDataSource<T> loadDataSource(URL csvPath, Set<String> responseNames, String[] header) throws IOException {
262        List<String> headerList = header == null ? Collections.emptyList() : Arrays.asList(header);
263        try (CSVIterator itr = new CSVIterator(csvPath.toURI(), separator, quote, headerList)) {
264            //
265            // CSVInteropProvenance constructor throws an exception on FileNotFound, so we include in the try/catch
266            DataSourceProvenance provenance = new CSVLoaderProvenance(
267                    csvPath,
268                    outputFactory.getProvenance(),
269                    String.join(",", responseNames), // If there are multiple responses, join them
270                    separator,
271                    quote
272            );
273            List<Example<T>> list = innerLoadFromCSV(itr, responseNames, csvPath.toString());
274            return new ListDataSource<>(list, outputFactory, provenance);
275        } catch (URISyntaxException e) {
276            throw new FileNotFoundException("Failed to read from URL '" + csvPath + "' as it could not be converted to a URI");
277        }
278    }
279
280    private List<Example<T>> innerLoadFromCSV(CSVIterator itr, Set<String> responseNames, String csvPath) {
281        validateResponseNames(responseNames, itr.getFields(), csvPath);
282        List<Example<T>> dataset = new ArrayList<>();
283        String responseName = responseNames.size() == 1 ? responseNames.iterator().next() : null;
284        //
285        // Create the examples.
286        while (itr.hasNext()) {
287            Map<String, String> row = itr.next().getRowData();
288            T label = (responseNames.size() == 1) ?
289                    buildOutput(responseName, row) :
290                    buildMultiOutput(responseNames, row);
291            ArrayExample<T> example = new ArrayExample<>(label);
292            for (Map.Entry<String, String> ent : row.entrySet()) {
293                String columnName = ent.getKey();
294                if (!responseNames.contains(columnName)) {
295                    //
296                    // If it's not a response, it's a feature
297                    double value = Double.parseDouble(ent.getValue());
298                    example.add(columnName, value);
299                }
300            }
301            dataset.add(example);
302        }
303        return dataset;
304    }
305
306    private static void validateResponseNames(Set<String> responseNames, List<String> headers, String csvPath) throws IllegalArgumentException {
307        if (responseNames.isEmpty()) {
308            throw new IllegalArgumentException("At least one response name must be specified, but responseNames is empty.");
309        }
310        //
311        // Validate that all the expected responses are included in the given header fields
312        Map<String, Boolean> responsesFound = new HashMap<>();
313        for (String response : responseNames) {
314            responsesFound.put(response, false);
315        }
316        for (String header : headers) {
317            if (responseNames.contains(header)) {
318                responsesFound.put(header, true);
319            }
320        }
321        for (Map.Entry<String, Boolean> kv : responsesFound.entrySet()) {
322            if (!kv.getValue()) {
323                throw new IllegalArgumentException(String.format("Response %s not found in file %s", kv.getKey(), csvPath));
324            }
325        }
326    }
327
328    private T buildOutput(String responseName, Map<String, String> row) {
329        String label = row.get(responseName);
330        T output = outputFactory.generateOutput(label);
331        return output;
332    }
333
334    /**
335     * Build a Output for a multi-output CSV file formatted like:
336     * <pre>
337     * Attr1,Attr2,...,Class1,Class2,Class3
338     * 1.0,0.5,...,true,true,false
339     * 1.0,0.5,...,true,false,false
340     * 1.0,0.5,...,false,true,true
341     * </pre>
342     * Or for multivariate regression,
343     * <pre>
344     * Attr1,Attr2,...,Var1,Var2,Var3
345     * 1.0,0.5,...,0.1,0.1,0.3
346     * 1.0,0.5,...,0.2,0.0,0.8
347     * </pre>
348     * @param responseNames The response dimension names
349     * @param row           The row to process.
350     */
351    private T buildMultiOutput(Set<String> responseNames, Map<String, String> row) {
352        Set<String> pairs = new HashSet<>();
353        for (String responseName : responseNames) {
354            String rawValue = row.get(responseName);
355            String pair = String.format("%s=%s", responseName, rawValue);
356            pairs.add(pair);
357        }
358        T output = outputFactory.generateOutput(pairs);
359        return output;
360    }
361
362    /**
363     * Provenance for CSVs loaded by {@link CSVLoader}.
364     */
365    public final static class CSVLoaderProvenance implements DataSourceProvenance {
366        private static final long serialVersionUID = 1L;
367
368        private static final String RESPONSE_NAME = "response-name";
369        private static final String SEP_PROV = "separator";
370        private static final String QUOTE_PROV = "quote";
371        private static final String PATH = "path";
372
373        private final StringProvenance className;
374        private final OutputFactoryProvenance factoryProvenance;
375
376        // In the multi-output case, the responseName will be a comma-separated list of response names
377        private final StringProvenance responseName;
378        private final CharProvenance separator;
379        private final CharProvenance quote;
380        private final URLProvenance path;
381        private final DateTimeProvenance fileModifiedTime;
382        private final HashProvenance sha256Hash;
383
384        CSVLoaderProvenance(URL path, OutputFactoryProvenance factoryProvenance, String responseName, char separator, char quote) {
385            this.className = new StringProvenance(CLASS_NAME, CSVLoader.class.getName());
386            this.factoryProvenance = factoryProvenance;
387            this.responseName = new StringProvenance(RESPONSE_NAME, responseName);
388            this.separator = new CharProvenance(SEP_PROV, separator);
389            this.quote = new CharProvenance(QUOTE_PROV, quote);
390            this.path = new URLProvenance(PATH, path);
391            Optional<OffsetDateTime> time = ProvenanceUtil.getModifiedTime(path);
392            this.fileModifiedTime = time.map(offsetDateTime -> new DateTimeProvenance(FILE_MODIFIED_TIME, offsetDateTime)).orElseGet(() -> new DateTimeProvenance(FILE_MODIFIED_TIME, OffsetDateTime.MIN));
393            this.sha256Hash = new HashProvenance(DEFAULT_HASH_TYPE, RESOURCE_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, path));
394        }
395
396        public CSVLoaderProvenance(Map<String, Provenance> map) {
397            this.className = ObjectProvenance.checkAndExtractProvenance(map, CLASS_NAME, StringProvenance.class, CSVLoaderProvenance.class.getSimpleName());
398            this.factoryProvenance = ObjectProvenance.checkAndExtractProvenance(map, OUTPUT_FACTORY, OutputFactoryProvenance.class, CSVLoaderProvenance.class.getSimpleName());
399            this.responseName = ObjectProvenance.checkAndExtractProvenance(map, RESPONSE_NAME, StringProvenance.class, CSVLoaderProvenance.class.getSimpleName());
400            this.separator = ObjectProvenance.checkAndExtractProvenance(map, SEP_PROV, CharProvenance.class, CSVLoaderProvenance.class.getSimpleName());
401            this.quote = ObjectProvenance.checkAndExtractProvenance(map, QUOTE_PROV, CharProvenance.class, CSVLoaderProvenance.class.getSimpleName());
402            this.path = ObjectProvenance.checkAndExtractProvenance(map, PATH, URLProvenance.class, CSVLoaderProvenance.class.getSimpleName());
403            this.fileModifiedTime = ObjectProvenance.checkAndExtractProvenance(map, FILE_MODIFIED_TIME, DateTimeProvenance.class, CSVLoaderProvenance.class.getSimpleName());
404            this.sha256Hash = ObjectProvenance.checkAndExtractProvenance(map, RESOURCE_HASH, HashProvenance.class, CSVLoaderProvenance.class.getSimpleName());
405        }
406
407        @Override
408        public String getClassName() {
409            return className.getValue();
410        }
411
412        @Override
413        public Iterator<Pair<String, Provenance>> iterator() {
414            ArrayList<Pair<String, Provenance>> list = new ArrayList<>();
415
416            list.add(new Pair<>(CLASS_NAME, className));
417            list.add(new Pair<>(OUTPUT_FACTORY, factoryProvenance));
418            list.add(new Pair<>(RESPONSE_NAME, responseName));
419            list.add(new Pair<>(SEP_PROV, separator));
420            list.add(new Pair<>(QUOTE_PROV, quote));
421            list.add(new Pair<>(PATH, path));
422            list.add(new Pair<>(FILE_MODIFIED_TIME, fileModifiedTime));
423            list.add(new Pair<>(RESOURCE_HASH, sha256Hash));
424
425            return list.iterator();
426        }
427
428        @Override
429        public boolean equals(Object o) {
430            if (this == o) return true;
431            if (!(o instanceof CSVLoaderProvenance)) return false;
432            CSVLoaderProvenance pairs = (CSVLoaderProvenance) o;
433            return className.equals(pairs.className) &&
434                    factoryProvenance.equals(pairs.factoryProvenance) &&
435                    responseName.equals(pairs.responseName) &&
436                    separator.equals(pairs.separator) &&
437                    quote.equals(pairs.quote) &&
438                    path.equals(pairs.path) &&
439                    fileModifiedTime.equals(pairs.fileModifiedTime) &&
440                    sha256Hash.equals(pairs.sha256Hash);
441        }
442
443        @Override
444        public int hashCode() {
445            return Objects.hash(className, factoryProvenance, responseName, separator, quote, path, fileModifiedTime, sha256Hash);
446        }
447
448        @Override
449        public String toString() {
450            return generateString("CSV");
451        }
452    }
453
454}