/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.data.csv;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import org.tribuo.DataSource;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.data.columnar.ColumnarDataSource;
import org.tribuo.data.columnar.ColumnarIterator;
import org.tribuo.data.columnar.FieldProcessor;
import org.tribuo.data.columnar.RowProcessor;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;

import java.io.IOException;
import java.io.InputStream;
import java.io.PushbackInputStream;
import java.io.PushbackReader;
import java.io.Reader;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Instant;
046import java.time.OffsetDateTime; 047import java.time.ZoneId; 048import java.util.HashMap; 049import java.util.Map; 050import java.util.Objects; 051 052/** 053 * A {@link DataSource} for loading separable data from a text file (e.g., CSV, TSV) 054 * and applying {@link FieldProcessor}s to it. 055 */ 056public class CSVDataSource<T extends Output<T>> extends ColumnarDataSource<T> { 057 058 private URI dataFile; 059 060 @Config(mandatory = true,description="Path to the CSV file.") 061 private Path dataPath; 062 063 @Config(description="The CSV separator character.") 064 private char separator = CSVIterator.SEPARATOR; 065 066 @Config(description="The CSV quote character.") 067 private char quote = CSVIterator.QUOTE; 068 069 private ConfiguredDataSourceProvenance provenance; 070 071 /** 072 * For OLCUT. 073 */ 074 private CSVDataSource() {} 075 076 /** 077 * Creates a CSVDataSource using the specified RowProcessor to process the data. 078 * 079 * <p> 080 * 081 * Uses ',' as the separator, '"' as the quote character, and '\' as the escape character. 082 * @param dataPath The Path to the data file. 083 * @param rowProcessor The row processor which converts a row into an {@link Example}. 084 * @param outputRequired Is the output required to exist in the data file. 085 */ 086 public CSVDataSource(Path dataPath, RowProcessor<T> rowProcessor, boolean outputRequired) { 087 this(dataPath,rowProcessor,outputRequired, CSVIterator.SEPARATOR, CSVIterator.QUOTE); 088 } 089 090 /** 091 * Creates a CSVDataSource using the specified RowProcessor to process the data. 092 * 093 * <p> 094 * 095 * Uses ',' as the separator, '"' as the quote character, and '\' as the escape character. 096 * @param dataFile A URI for the data file. 097 * @param rowProcessor The row processor which converts a row into an {@link Example}. 098 * @param outputRequired Is the output required to exist in the data file. 
099 */ 100 public CSVDataSource(URI dataFile, RowProcessor<T> rowProcessor, boolean outputRequired) { 101 this(dataFile,rowProcessor,outputRequired, CSVIterator.SEPARATOR, CSVIterator.QUOTE); 102 } 103 104 /** 105 * Creates a CSVDataSource using the specified RowProcessor to process the data. 106 * 107 * <p> 108 * 109 * Uses '"' as the quote character, and '\' as the escape character. 110 * @param dataPath The Path to the data file. 111 * @param rowProcessor The row processor which converts a row into an {@link Example}. 112 * @param outputRequired Is the output required to exist in the data file. 113 * @param separator The separator character in the data file. 114 */ 115 public CSVDataSource(Path dataPath, RowProcessor<T> rowProcessor, boolean outputRequired, char separator) { 116 this(dataPath,rowProcessor,outputRequired,separator, CSVIterator.QUOTE); 117 } 118 119 /** 120 * Creates a CSVDataSource using the specified RowProcessor to process the data. 121 * 122 * <p> 123 * 124 * Uses '"' as the quote character, and '\' as the escape character. 125 * @param dataFile A URI for the data file. 126 * @param rowProcessor The row processor which converts a row into an {@link Example}. 127 * @param outputRequired Is the output required to exist in the data file. 128 * @param separator The separator character in the data file. 129 */ 130 public CSVDataSource(URI dataFile, RowProcessor<T> rowProcessor, boolean outputRequired, char separator) { 131 this(dataFile,rowProcessor,outputRequired,separator, CSVIterator.QUOTE); 132 } 133 134 /** 135 * Creates a CSVDataSource using the specified RowProcessor to process the data, and the supplied separator and quote 136 * characters to read the input data file. 137 * @param dataFile A URI for the data file. 138 * @param rowProcessor The row processor which converts a row into an {@link Example}. 139 * @param outputRequired Is the output required to exist in the data file. 
140 * @param separator The separator character in the data file. 141 * @param quote The quote character in the data file. 142 */ 143 public CSVDataSource(URI dataFile, RowProcessor<T> rowProcessor, boolean outputRequired, char separator, char quote) { 144 this(dataFile, Paths.get(dataFile),rowProcessor,outputRequired,separator,quote); 145 } 146 147 /** 148 * Creates a CSVDataSource using the specified RowProcessor to process the data, and the supplied separator and quote 149 * characters to read the input data file. 150 * @param dataPath The Path to the data file. 151 * @param rowProcessor The row processor which converts a row into an {@link Example}. 152 * @param outputRequired Is the output required to exist in the data file. 153 * @param separator The separator character in the data file. 154 * @param quote The quote character in the data file. 155 */ 156 public CSVDataSource(Path dataPath, RowProcessor<T> rowProcessor, boolean outputRequired, char separator, char quote) { 157 this(dataPath.toUri(),dataPath,rowProcessor,outputRequired,separator,quote); 158 } 159 160 /** 161 * Creates a CSVDataSource using the specified RowProcessor to process the data, and the supplied separator, quote 162 * characters to read the input data file. 163 * @param dataFile A URI for the data file. 164 * @param rowProcessor The row processor which converts a row into an {@link Example}. 165 * @param outputRequired Is the output required to exist in the data file. 166 * @param separator The separator character in the data file. 167 * @param quote The quote character in the data file. 
168 */ 169 private CSVDataSource(URI dataFile, Path dataPath, RowProcessor<T> rowProcessor, boolean outputRequired, char separator, char quote) { 170 super(rowProcessor.getResponseProcessor().getOutputFactory(), rowProcessor, outputRequired); 171 this.dataPath = dataPath; 172 this.dataFile = dataFile; 173 this.separator = separator; 174 this.quote = quote; 175 this.provenance = new CSVDataSourceProvenance(this); 176 } 177 178 /** 179 * Used by the OLCUT configuration system, and should not be called by external code. 180 */ 181 @Override 182 public void postConfig() { 183 this.dataFile = dataPath.toUri(); 184 this.provenance = new CSVDataSourceProvenance(this); 185 } 186 187 @Override 188 public String toString() { 189 return "CSVDataSource(file=" + dataFile + ",rowProcessor="+rowProcessor.getDescription()+")"; 190 } 191 192 @Override 193 public ColumnarIterator rowIterator() { 194 try { 195 return new CSVIterator(dataFile, separator, quote); 196 } catch (IOException e) { 197 throw new IllegalStateException("Failed to read data",e); 198 } 199 } 200 201 @Override 202 public ConfiguredDataSourceProvenance getProvenance() { 203 return provenance; 204 } 205 206 /** 207 * Removes a UTF-8 byte order mark if it exists. 208 * <p> 209 * Note Tribuo only supports UTF-8 inputs, so the other BOMs are not checked for. 210 * @param stream The stream to check. 211 * @return An input stream with the BOM consumed (if present). 212 * @throws IOException If the stream failed to read. 213 */ 214 static InputStream removeBOM(InputStream stream) throws IOException { 215 PushbackInputStream pushbackStream = new PushbackInputStream(stream,3); 216 byte[] bomBytes = new byte[3]; 217 int bytesRead = pushbackStream.read(bomBytes,0,3); 218 if (!((bomBytes[0] == (byte)0xEF) && (bomBytes[1] == (byte)0xBB) && (bomBytes[2] == (byte)0xBF))) { 219 pushbackStream.unread(bomBytes); 220 } 221 return pushbackStream; 222 } 223 224 /** 225 * Removes a UTF-8 byte order mark if it exists. 
226 * <p> 227 * Note Tribuo only supports UTF-8 inputs, so the other BOMs are not checked for. 228 * @param reader The reader to check. 229 * @return A reader with the BOM consumed (if present). 230 * @throws IOException If the reader failed to read. 231 */ 232 static Reader removeBOM(Reader reader) throws IOException { 233 PushbackReader pushbackStream = new PushbackReader(reader,1); 234 int bomChar = pushbackStream.read(); 235 if (!(bomChar == 0xFEFF)) { 236 pushbackStream.unread(bomChar); 237 } 238 return pushbackStream; 239 } 240 241 /** 242 * Provenance for {@link CSVDataSource}. 243 */ 244 public static class CSVDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance { 245 private static final long serialVersionUID = 1L; 246 247 private final DateTimeProvenance fileModifiedTime; 248 private final DateTimeProvenance dataSourceCreationTime; 249 private final HashProvenance sha256Hash; 250 251 <T extends Output<T>> CSVDataSourceProvenance(CSVDataSource<T> host) { 252 super(host,"DataSource"); 253 this.fileModifiedTime = new DateTimeProvenance(FILE_MODIFIED_TIME,OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.dataPath.toFile().lastModified()), ZoneId.systemDefault())); 254 this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME,OffsetDateTime.now()); 255 this.sha256Hash = new HashProvenance(DEFAULT_HASH_TYPE,RESOURCE_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.dataPath)); 256 } 257 258 public CSVDataSourceProvenance(Map<String,Provenance> map) { 259 this(extractProvenanceInfo(map)); 260 } 261 262 private CSVDataSourceProvenance(ExtractedInfo info) { 263 super(info); 264 this.fileModifiedTime = (DateTimeProvenance) info.instanceValues.get(FILE_MODIFIED_TIME); 265 this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME); 266 this.sha256Hash = (HashProvenance) info.instanceValues.get(RESOURCE_HASH); 267 } 268 269 protected static 
ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) { 270 Map<String,Provenance> configuredParameters = new HashMap<>(map); 271 String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, CSVDataSourceProvenance.class.getSimpleName()).getValue(); 272 String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, CSVDataSourceProvenance.class.getSimpleName()).getValue(); 273 274 Map<String,PrimitiveProvenance<?>> instanceParameters = new HashMap<>(); 275 instanceParameters.put(FILE_MODIFIED_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,FILE_MODIFIED_TIME,DateTimeProvenance.class, CSVDataSourceProvenance.class.getSimpleName())); 276 instanceParameters.put(DATASOURCE_CREATION_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,DATASOURCE_CREATION_TIME,DateTimeProvenance.class, CSVDataSourceProvenance.class.getSimpleName())); 277 instanceParameters.put(RESOURCE_HASH,ObjectProvenance.checkAndExtractProvenance(configuredParameters,RESOURCE_HASH,HashProvenance.class, CSVDataSourceProvenance.class.getSimpleName())); 278 279 return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceParameters); 280 } 281 282 @Override 283 public boolean equals(Object o) { 284 if (this == o) return true; 285 if (o == null || getClass() != o.getClass()) return false; 286 if (!super.equals(o)) return false; 287 CSVDataSourceProvenance pairs = (CSVDataSourceProvenance) o; 288 return fileModifiedTime.equals(pairs.fileModifiedTime) && 289 dataSourceCreationTime.equals(pairs.dataSourceCreationTime) && 290 sha256Hash.equals(pairs.sha256Hash); 291 } 292 293 @Override 294 public int hashCode() { 295 return Objects.hash(super.hashCode(), fileModifiedTime, dataSourceCreationTime, sha256Hash); 296 } 297 298 @Override 299 public Map<String, PrimitiveProvenance<?>> getInstanceValues() { 300 
Map<String,PrimitiveProvenance<?>> map = super.getInstanceValues(); 301 302 map.put(FILE_MODIFIED_TIME,fileModifiedTime); 303 map.put(DATASOURCE_CREATION_TIME,dataSourceCreationTime); 304 map.put(RESOURCE_HASH,sha256Hash); 305 306 return map; 307 } 308 } 309}