001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.csv;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.Provenance;
023import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
024import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
026import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
027import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
028import org.tribuo.DataSource;
029import org.tribuo.Example;
030import org.tribuo.Output;
031import org.tribuo.data.columnar.ColumnarDataSource;
032import org.tribuo.data.columnar.ColumnarIterator;
033import org.tribuo.data.columnar.FieldProcessor;
034import org.tribuo.data.columnar.RowProcessor;
035import org.tribuo.provenance.ConfiguredDataSourceProvenance;
036
037import java.io.IOException;
038import java.io.InputStream;
039import java.io.PushbackInputStream;
040import java.io.PushbackReader;
041import java.io.Reader;
042import java.net.URI;
043import java.nio.file.Path;
044import java.nio.file.Paths;
045import java.time.Instant;
046import java.time.OffsetDateTime;
047import java.time.ZoneId;
048import java.util.HashMap;
049import java.util.Map;
050import java.util.Objects;
051
052/**
053 * A {@link DataSource} for loading separable data from a text file (e.g., CSV, TSV)
054 * and applying {@link FieldProcessor}s to it.
055 */
056public class CSVDataSource<T extends Output<T>> extends ColumnarDataSource<T> {
057
058    private URI dataFile;
059
060    @Config(mandatory = true,description="Path to the CSV file.")
061    private Path dataPath;
062
063    @Config(description="The CSV separator character.")
064    private char separator = CSVIterator.SEPARATOR;
065
066    @Config(description="The CSV quote character.")
067    private char quote = CSVIterator.QUOTE;
068
069    private ConfiguredDataSourceProvenance provenance;
070
071    /**
072     * For OLCUT.
073     */
074    private CSVDataSource() {}
075
076    /**
077     * Creates a CSVDataSource using the specified RowProcessor to process the data.
078     *
079     * <p>
080     *
081     * Uses ',' as the separator, '"' as the quote character, and '\' as the escape character.
082     * @param dataPath The Path to the data file.
083     * @param rowProcessor The row processor which converts a row into an {@link Example}.
084     * @param outputRequired Is the output required to exist in the data file.
085     */
086    public CSVDataSource(Path dataPath, RowProcessor<T> rowProcessor, boolean outputRequired) {
087        this(dataPath,rowProcessor,outputRequired, CSVIterator.SEPARATOR, CSVIterator.QUOTE);
088    }
089
090    /**
091     * Creates a CSVDataSource using the specified RowProcessor to process the data.
092     *
093     * <p>
094     *
095     * Uses ',' as the separator, '"' as the quote character, and '\' as the escape character.
096     * @param dataFile A URI for the data file.
097     * @param rowProcessor The row processor which converts a row into an {@link Example}.
098     * @param outputRequired Is the output required to exist in the data file.
099     */
100    public CSVDataSource(URI dataFile, RowProcessor<T> rowProcessor, boolean outputRequired) {
101        this(dataFile,rowProcessor,outputRequired, CSVIterator.SEPARATOR, CSVIterator.QUOTE);
102    }
103
104    /**
105     * Creates a CSVDataSource using the specified RowProcessor to process the data.
106     *
107     * <p>
108     *
109     * Uses '"' as the quote character, and '\' as the escape character.
110     * @param dataPath The Path to the data file.
111     * @param rowProcessor The row processor which converts a row into an {@link Example}.
112     * @param outputRequired Is the output required to exist in the data file.
113     * @param separator The separator character in the data file.
114     */
115    public CSVDataSource(Path dataPath, RowProcessor<T> rowProcessor, boolean outputRequired, char separator) {
116        this(dataPath,rowProcessor,outputRequired,separator, CSVIterator.QUOTE);
117    }
118
119    /**
120     * Creates a CSVDataSource using the specified RowProcessor to process the data.
121     *
122     * <p>
123     *
124     * Uses '"' as the quote character, and '\' as the escape character.
125     * @param dataFile A URI for the data file.
126     * @param rowProcessor The row processor which converts a row into an {@link Example}.
127     * @param outputRequired Is the output required to exist in the data file.
128     * @param separator The separator character in the data file.
129     */
130    public CSVDataSource(URI dataFile, RowProcessor<T> rowProcessor, boolean outputRequired, char separator) {
131        this(dataFile,rowProcessor,outputRequired,separator, CSVIterator.QUOTE);
132    }
133
134    /**
135     * Creates a CSVDataSource using the specified RowProcessor to process the data, and the supplied separator and quote
136     * characters to read the input data file.
137     * @param dataFile A URI for the data file.
138     * @param rowProcessor The row processor which converts a row into an {@link Example}.
139     * @param outputRequired Is the output required to exist in the data file.
140     * @param separator The separator character in the data file.
141     * @param quote The quote character in the data file.
142     */
143    public CSVDataSource(URI dataFile, RowProcessor<T> rowProcessor, boolean outputRequired, char separator, char quote) {
144        this(dataFile, Paths.get(dataFile),rowProcessor,outputRequired,separator,quote);
145    }
146
147    /**
148     * Creates a CSVDataSource using the specified RowProcessor to process the data, and the supplied separator and quote
149     * characters to read the input data file.
150     * @param dataPath The Path to the data file.
151     * @param rowProcessor The row processor which converts a row into an {@link Example}.
152     * @param outputRequired Is the output required to exist in the data file.
153     * @param separator The separator character in the data file.
154     * @param quote The quote character in the data file.
155     */
156    public CSVDataSource(Path dataPath, RowProcessor<T> rowProcessor, boolean outputRequired, char separator, char quote) {
157        this(dataPath.toUri(),dataPath,rowProcessor,outputRequired,separator,quote);
158    }
159
160    /**
161     * Creates a CSVDataSource using the specified RowProcessor to process the data, and the supplied separator, quote
162     * characters to read the input data file.
163     * @param dataFile A URI for the data file.
164     * @param rowProcessor The row processor which converts a row into an {@link Example}.
165     * @param outputRequired Is the output required to exist in the data file.
166     * @param separator The separator character in the data file.
167     * @param quote The quote character in the data file.
168     */
169    private CSVDataSource(URI dataFile, Path dataPath, RowProcessor<T> rowProcessor, boolean outputRequired, char separator, char quote) {
170        super(rowProcessor.getResponseProcessor().getOutputFactory(), rowProcessor, outputRequired);
171        this.dataPath = dataPath;
172        this.dataFile = dataFile;
173        this.separator = separator;
174        this.quote = quote;
175        this.provenance = new CSVDataSourceProvenance(this);
176    }
177
178    /**
179     * Used by the OLCUT configuration system, and should not be called by external code.
180     */
181    @Override
182    public void postConfig() {
183        this.dataFile = dataPath.toUri();
184        this.provenance = new CSVDataSourceProvenance(this);
185    }
186
187    @Override
188    public String toString() {
189        return "CSVDataSource(file=" + dataFile + ",rowProcessor="+rowProcessor.getDescription()+")";
190    }
191
192    @Override
193    public ColumnarIterator rowIterator() {
194        try {
195            return new CSVIterator(dataFile, separator, quote);
196        } catch (IOException e) {
197            throw new IllegalStateException("Failed to read data",e);
198        }
199    }
200
201    @Override
202    public ConfiguredDataSourceProvenance getProvenance() {
203        return provenance;
204    }
205
206    /**
207     * Removes a UTF-8 byte order mark if it exists.
208     * <p>
209     * Note Tribuo only supports UTF-8 inputs, so the other BOMs are not checked for.
210     * @param stream The stream to check.
211     * @return An input stream with the BOM consumed (if present).
212     * @throws IOException If the stream failed to read.
213     */
214    static InputStream removeBOM(InputStream stream) throws IOException {
215        PushbackInputStream pushbackStream = new PushbackInputStream(stream,3);
216        byte[] bomBytes = new byte[3];
217        int bytesRead = pushbackStream.read(bomBytes,0,3);
218        if (!((bomBytes[0] == (byte)0xEF) && (bomBytes[1] == (byte)0xBB) && (bomBytes[2] == (byte)0xBF))) {
219            pushbackStream.unread(bomBytes);
220        }
221        return pushbackStream;
222    }
223
224    /**
225     * Removes a UTF-8 byte order mark if it exists.
226     * <p>
227     * Note Tribuo only supports UTF-8 inputs, so the other BOMs are not checked for.
228     * @param reader The reader to check.
229     * @return A reader with the BOM consumed (if present).
230     * @throws IOException If the reader failed to read.
231     */
232    static Reader removeBOM(Reader reader) throws IOException {
233        PushbackReader pushbackStream = new PushbackReader(reader,1);
234        int bomChar = pushbackStream.read();
235        if (!(bomChar == 0xFEFF)) {
236            pushbackStream.unread(bomChar);
237        }
238        return pushbackStream;
239    }
240
241    /**
242     * Provenance for {@link CSVDataSource}.
243     */
244    public static class CSVDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
245        private static final long serialVersionUID = 1L;
246
247        private final DateTimeProvenance fileModifiedTime;
248        private final DateTimeProvenance dataSourceCreationTime;
249        private final HashProvenance sha256Hash;
250
251        <T extends Output<T>> CSVDataSourceProvenance(CSVDataSource<T> host) {
252            super(host,"DataSource");
253            this.fileModifiedTime = new DateTimeProvenance(FILE_MODIFIED_TIME,OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.dataPath.toFile().lastModified()), ZoneId.systemDefault()));
254            this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME,OffsetDateTime.now());
255            this.sha256Hash = new HashProvenance(DEFAULT_HASH_TYPE,RESOURCE_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.dataPath));
256        }
257
258        public CSVDataSourceProvenance(Map<String,Provenance> map) {
259            this(extractProvenanceInfo(map));
260        }
261
262        private CSVDataSourceProvenance(ExtractedInfo info) {
263            super(info);
264            this.fileModifiedTime = (DateTimeProvenance) info.instanceValues.get(FILE_MODIFIED_TIME);
265            this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME);
266            this.sha256Hash = (HashProvenance) info.instanceValues.get(RESOURCE_HASH);
267        }
268
269        protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) {
270            Map<String,Provenance> configuredParameters = new HashMap<>(map);
271            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, CSVDataSourceProvenance.class.getSimpleName()).getValue();
272            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, CSVDataSourceProvenance.class.getSimpleName()).getValue();
273
274            Map<String,PrimitiveProvenance<?>> instanceParameters = new HashMap<>();
275            instanceParameters.put(FILE_MODIFIED_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,FILE_MODIFIED_TIME,DateTimeProvenance.class, CSVDataSourceProvenance.class.getSimpleName()));
276            instanceParameters.put(DATASOURCE_CREATION_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,DATASOURCE_CREATION_TIME,DateTimeProvenance.class, CSVDataSourceProvenance.class.getSimpleName()));
277            instanceParameters.put(RESOURCE_HASH,ObjectProvenance.checkAndExtractProvenance(configuredParameters,RESOURCE_HASH,HashProvenance.class, CSVDataSourceProvenance.class.getSimpleName()));
278
279            return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceParameters);
280        }
281
282        @Override
283        public boolean equals(Object o) {
284            if (this == o) return true;
285            if (o == null || getClass() != o.getClass()) return false;
286            if (!super.equals(o)) return false;
287            CSVDataSourceProvenance pairs = (CSVDataSourceProvenance) o;
288            return fileModifiedTime.equals(pairs.fileModifiedTime) &&
289                    dataSourceCreationTime.equals(pairs.dataSourceCreationTime) &&
290                    sha256Hash.equals(pairs.sha256Hash);
291        }
292
293        @Override
294        public int hashCode() {
295            return Objects.hash(super.hashCode(), fileModifiedTime, dataSourceCreationTime, sha256Hash);
296        }
297
298        @Override
299        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
300            Map<String,PrimitiveProvenance<?>> map = super.getInstanceValues();
301
302            map.put(FILE_MODIFIED_TIME,fileModifiedTime);
303            map.put(DATASOURCE_CREATION_TIME,dataSourceCreationTime);
304            map.put(RESOURCE_HASH,sha256Hash);
305
306            return map;
307        }
308    }
309}