001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.data.csv; 018 019import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil; 022import com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance; 023import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance; 024import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance; 025import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 026import com.oracle.labs.mlrg.olcut.provenance.primitives.URLProvenance; 027import com.oracle.labs.mlrg.olcut.util.Pair; 028import org.tribuo.Example; 029import org.tribuo.MutableDataset; 030import org.tribuo.Output; 031import org.tribuo.OutputFactory; 032import org.tribuo.datasource.ListDataSource; 033import org.tribuo.impl.ArrayExample; 034import org.tribuo.provenance.DataSourceProvenance; 035import org.tribuo.provenance.OutputFactoryProvenance; 036 037import java.io.BufferedReader; 038import java.io.FileInputStream; 039import java.io.FileNotFoundException; 040import java.io.IOException; 041import java.io.InputStreamReader; 042import java.io.Reader; 043import java.net.URISyntaxException; 044import java.net.URL; 045import java.nio.charset.StandardCharsets; 046import java.nio.file.Files; 047import java.nio.file.Path; 048import java.time.OffsetDateTime; 049import java.util.ArrayList; 050import java.util.Arrays; 051import java.util.Collections; 052import java.util.HashMap; 053import java.util.HashSet; 054import java.util.Iterator; 055import java.util.List; 056import java.util.Map; 057import java.util.Objects; 058import java.util.Optional; 059import java.util.Set; 060import java.util.logging.Logger; 061 062/** 063 * Load a DataSource/Dataset from a CSV file. 064 * <p> 065 * The delimiter and quote characters are user controlled, so this class can parse TSVs, 066 * CSVs, semi-colon separated data and other types of single character delimiter separated data. 067 * <p> 068 * This class is a simple loader *only* for numerical CSV files with a String response field. 069 * If you need more complex processing, the response field isn't present, or you don't wish to 070 * use all of the columns as features then you should use {@link CSVDataSource} and build a 071 * {@link org.tribuo.data.columnar.RowProcessor} to cope with your specific input format. 072 * <p> 073 * CSVLoader is thread safe and immutable. 074 * @param <T> The type of the output generated. 075 */ 076public class CSVLoader<T extends Output<T>> { 077 078 private static final Logger logger = Logger.getLogger(CSVLoader.class.getName()); 079 080 private final char separator; 081 private final char quote; 082 private final OutputFactory<T> outputFactory; 083 084 /** 085 * Creates a CSVLoader using the supplied separator, quote and output factory. 086 * @param separator The separator character. 087 * @param quote The quote character. 088 * @param outputFactory The output factory. 089 */ 090 public CSVLoader(char separator, char quote, OutputFactory<T> outputFactory) { 091 this.separator = separator; 092 this.quote = quote; 093 this.outputFactory = outputFactory; 094 } 095 096 /** 097 * Creates a CSVLoader using the supplied separator and output factory. 098 * Sets the quote to {@link CSVIterator#QUOTE}. 099 * @param separator The separator character. 100 * @param outputFactory The output factory. 101 */ 102 public CSVLoader(char separator, OutputFactory<T> outputFactory) { 103 this(separator, CSVIterator.QUOTE, outputFactory); 104 } 105 106 /** 107 * Creates a CSVLoader using the supplied output factory. 108 * Sets the separator to {@link CSVIterator#SEPARATOR} and the quote to {@link CSVIterator#QUOTE}. 109 * @param outputFactory The output factory. 110 */ 111 public CSVLoader(OutputFactory<T> outputFactory) { 112 this(CSVIterator.SEPARATOR, CSVIterator.QUOTE, outputFactory); 113 } 114 115 /** 116 * Loads a DataSource from the specified csv file then wraps it in a dataset. 117 * 118 * @param csvPath The path to load. 119 * @param responseName The name of the response variable. 120 * @return A dataset containing the csv data. 121 * @throws IOException If the read failed. 122 */ 123 public MutableDataset<T> load(Path csvPath, String responseName) throws IOException { 124 return new MutableDataset<>(loadDataSource(csvPath, responseName)); 125 } 126 127 /** 128 * Loads a DataSource from the specified csv file then wraps it in a dataset. 129 * 130 * @param csvPath The path to load. 131 * @param responseName The name of the response variable. 132 * @param header The header of the CSV if it's not present in the file. 133 * @return A dataset containing the csv data. 134 * @throws IOException If the read failed. 135 */ 136 public MutableDataset<T> load(Path csvPath, String responseName, String[] header) throws IOException { 137 return new MutableDataset<>(loadDataSource(csvPath, responseName, header)); 138 } 139 140 /** 141 * Loads a DataSource from the specified csv file then wraps it in a dataset. 142 * 143 * @param csvPath The path to load. 144 * @param responseNames The names of the response variables. 145 * @return A dataset containing the csv data. 146 * @throws IOException If the read failed. 147 */ 148 public MutableDataset<T> load(Path csvPath, Set<String> responseNames) throws IOException { 149 return new MutableDataset<>(loadDataSource(csvPath, responseNames)); 150 } 151 152 /** 153 * Loads a DataSource from the specified csv file then wraps it in a dataset. 154 * 155 * @param csvPath The path to load. 156 * @param responseNames The names of the response variables. 157 * @param header The header of the CSV if it's not present in the file. 158 * @return A dataset containing the csv data. 159 * @throws IOException If the read failed. 160 */ 161 public MutableDataset<T> load(Path csvPath, Set<String> responseNames, String[] header) throws IOException { 162 return new MutableDataset<>(loadDataSource(csvPath, responseNames, header)); 163 } 164 165 /** 166 * Loads a DataSource from the specified csv path. 167 * 168 * @param csvPath The csv to load from. 169 * @param responseName The name of the response variable. 170 * @return A datasource containing the csv data. 171 * @throws IOException If the disk read failed. 172 */ 173 public ListDataSource<T> loadDataSource(Path csvPath, String responseName) throws IOException { 174 return loadDataSource(csvPath, Collections.singleton(responseName)); 175 } 176 177 /** 178 * Loads a DataSource from the specified csv path. 179 * 180 * @param csvPath The csv to load from. 181 * @param responseName The name of the response variable. 182 * @return A datasource containing the csv data. 183 * @throws IOException If the disk read failed. 184 */ 185 public ListDataSource<T> loadDataSource(URL csvPath, String responseName) throws IOException { 186 return loadDataSource(csvPath, Collections.singleton(responseName)); 187 } 188 189 /** 190 * Loads a DataSource from the specified csv path. 191 * 192 * @param csvPath The csv to load from. 193 * @param responseName The name of the response variable. 194 * @param header The header of the CSV if it's not present in the file. 195 * @return A datasource containing the csv data. 196 * @throws IOException If the disk read failed. 197 */ 198 public ListDataSource<T> loadDataSource(Path csvPath, String responseName, String[] header) throws IOException { 199 return loadDataSource(csvPath, Collections.singleton(responseName), header); 200 } 201 202 /** 203 * Loads a DataSource from the specified csv path. 204 * 205 * @param csvPath The csv to load from. 206 * @param responseName The name of the response variable. 207 * @param header The header of the CSV if it's not present in the file. 208 * @return A datasource containing the csv data. 209 * @throws IOException If the disk read failed. 210 */ 211 public ListDataSource<T> loadDataSource(URL csvPath, String responseName, String[] header) throws IOException { 212 return loadDataSource(csvPath, Collections.singleton(responseName), header); 213 } 214 215 /** 216 * Loads a DataSource from the specified csv path. 217 * 218 * @param csvPath The csv to load from. 219 * @param responseNames The names of the response variables. 220 * @return A datasource containing the csv data. 221 * @throws IOException If the disk read failed. 222 */ 223 public ListDataSource<T> loadDataSource(Path csvPath, Set<String> responseNames) throws IOException { 224 return loadDataSource(csvPath, responseNames, null); 225 } 226 227 /** 228 * Loads a DataSource from the specified csv path. 229 * 230 * @param csvPath The csv to load from. 231 * @param responseNames The names of the response variables. 232 * @return A datasource containing the csv data. 233 * @throws IOException If the disk read failed. 234 */ 235 public ListDataSource<T> loadDataSource(URL csvPath, Set<String> responseNames) throws IOException { 236 return loadDataSource(csvPath, responseNames, null); 237 } 238 239 /** 240 * Loads a DataSource from the specified csv path. 241 * 242 * @param csvPath The csv to load from. 243 * @param responseNames The names of the response variables. 244 * @param header The header of the CSV if it's not present in the file. 245 * @return A datasource containing the csv data. 246 * @throws IOException If the disk read failed. 247 */ 248 public ListDataSource<T> loadDataSource(Path csvPath, Set<String> responseNames, String[] header) throws IOException { 249 return loadDataSource(csvPath.toUri().toURL(),responseNames,header); 250 } 251 252 /** 253 * Loads a DataSource from the specified csv path. 254 * 255 * @param csvPath The csv to load from. 256 * @param responseNames The names of the response variables. 257 * @param header The header of the CSV if it's not present in the file. 258 * @return A datasource containing the csv data. 259 * @throws IOException If the disk read failed. 260 */ 261 public ListDataSource<T> loadDataSource(URL csvPath, Set<String> responseNames, String[] header) throws IOException { 262 List<String> headerList = header == null ? Collections.emptyList() : Arrays.asList(header); 263 try (CSVIterator itr = new CSVIterator(csvPath.toURI(), separator, quote, headerList)) { 264 // 265 // CSVInteropProvenance constructor throws an exception on FileNotFound, so we include in the try/catch 266 DataSourceProvenance provenance = new CSVLoaderProvenance( 267 csvPath, 268 outputFactory.getProvenance(), 269 String.join(",", responseNames), // If there are multiple responses, join them 270 separator, 271 quote 272 ); 273 List<Example<T>> list = innerLoadFromCSV(itr, responseNames, csvPath.toString()); 274 return new ListDataSource<>(list, outputFactory, provenance); 275 } catch (URISyntaxException e) { 276 throw new FileNotFoundException("Failed to read from URL '" + csvPath + "' as it could not be converted to a URI"); 277 } 278 } 279 280 private List<Example<T>> innerLoadFromCSV(CSVIterator itr, Set<String> responseNames, String csvPath) { 281 validateResponseNames(responseNames, itr.getFields(), csvPath); 282 List<Example<T>> dataset = new ArrayList<>(); 283 String responseName = responseNames.size() == 1 ? responseNames.iterator().next() : null; 284 // 285 // Create the examples. 286 while (itr.hasNext()) { 287 Map<String, String> row = itr.next().getRowData(); 288 T label = (responseNames.size() == 1) ? 289 buildOutput(responseName, row) : 290 buildMultiOutput(responseNames, row); 291 ArrayExample<T> example = new ArrayExample<>(label); 292 for (Map.Entry<String, String> ent : row.entrySet()) { 293 String columnName = ent.getKey(); 294 if (!responseNames.contains(columnName)) { 295 // 296 // If it's not a response, it's a feature 297 double value = Double.parseDouble(ent.getValue()); 298 example.add(columnName, value); 299 } 300 } 301 dataset.add(example); 302 } 303 return dataset; 304 } 305 306 private static void validateResponseNames(Set<String> responseNames, List<String> headers, String csvPath) throws IllegalArgumentException { 307 if (responseNames.isEmpty()) { 308 throw new IllegalArgumentException("At least one response name must be specified, but responseNames is empty."); 309 } 310 // 311 // Validate that all the expected responses are included in the given header fields 312 Map<String, Boolean> responsesFound = new HashMap<>(); 313 for (String response : responseNames) { 314 responsesFound.put(response, false); 315 } 316 for (String header : headers) { 317 if (responseNames.contains(header)) { 318 responsesFound.put(header, true); 319 } 320 } 321 for (Map.Entry<String, Boolean> kv : responsesFound.entrySet()) { 322 if (!kv.getValue()) { 323 throw new IllegalArgumentException(String.format("Response %s not found in file %s", kv.getKey(), csvPath)); 324 } 325 } 326 } 327 328 private T buildOutput(String responseName, Map<String, String> row) { 329 String label = row.get(responseName); 330 T output = outputFactory.generateOutput(label); 331 return output; 332 } 333 334 /** 335 * Build a Output for a multi-output CSV file formatted like: 336 * <pre> 337 * Attr1,Attr2,...,Class1,Class2,Class3 338 * 1.0,0.5,...,true,true,false 339 * 1.0,0.5,...,true,false,false 340 * 1.0,0.5,...,false,true,true 341 * </pre> 342 * Or for multivariate regression, 343 * <pre> 344 * Attr1,Attr2,...,Var1,Var2,Var3 345 * 1.0,0.5,...,0.1,0.1,0.3 346 * 1.0,0.5,...,0.2,0.0,0.8 347 * </pre> 348 * @param responseNames The response dimension names 349 * @param row The row to process. 350 */ 351 private T buildMultiOutput(Set<String> responseNames, Map<String, String> row) { 352 Set<String> pairs = new HashSet<>(); 353 for (String responseName : responseNames) { 354 String rawValue = row.get(responseName); 355 String pair = String.format("%s=%s", responseName, rawValue); 356 pairs.add(pair); 357 } 358 T output = outputFactory.generateOutput(pairs); 359 return output; 360 } 361 362 /** 363 * Provenance for CSVs loaded by {@link CSVLoader}. 364 */ 365 public final static class CSVLoaderProvenance implements DataSourceProvenance { 366 private static final long serialVersionUID = 1L; 367 368 private static final String RESPONSE_NAME = "response-name"; 369 private static final String SEP_PROV = "separator"; 370 private static final String QUOTE_PROV = "quote"; 371 private static final String PATH = "path"; 372 373 private final StringProvenance className; 374 private final OutputFactoryProvenance factoryProvenance; 375 376 // In the multi-output case, the responseName will be a comma-separated list of response names 377 private final StringProvenance responseName; 378 private final CharProvenance separator; 379 private final CharProvenance quote; 380 private final URLProvenance path; 381 private final DateTimeProvenance fileModifiedTime; 382 private final HashProvenance sha256Hash; 383 384 CSVLoaderProvenance(URL path, OutputFactoryProvenance factoryProvenance, String responseName, char separator, char quote) { 385 this.className = new StringProvenance(CLASS_NAME, CSVLoader.class.getName()); 386 this.factoryProvenance = factoryProvenance; 387 this.responseName = new StringProvenance(RESPONSE_NAME, responseName); 388 this.separator = new CharProvenance(SEP_PROV, separator); 389 this.quote = new CharProvenance(QUOTE_PROV, quote); 390 this.path = new URLProvenance(PATH, path); 391 Optional<OffsetDateTime> time = ProvenanceUtil.getModifiedTime(path); 392 this.fileModifiedTime = time.map(offsetDateTime -> new DateTimeProvenance(FILE_MODIFIED_TIME, offsetDateTime)).orElseGet(() -> new DateTimeProvenance(FILE_MODIFIED_TIME, OffsetDateTime.MIN)); 393 this.sha256Hash = new HashProvenance(DEFAULT_HASH_TYPE, RESOURCE_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, path)); 394 } 395 396 public CSVLoaderProvenance(Map<String, Provenance> map) { 397 this.className = ObjectProvenance.checkAndExtractProvenance(map, CLASS_NAME, StringProvenance.class, CSVLoaderProvenance.class.getSimpleName()); 398 this.factoryProvenance = ObjectProvenance.checkAndExtractProvenance(map, OUTPUT_FACTORY, OutputFactoryProvenance.class, CSVLoaderProvenance.class.getSimpleName()); 399 this.responseName = ObjectProvenance.checkAndExtractProvenance(map, RESPONSE_NAME, StringProvenance.class, CSVLoaderProvenance.class.getSimpleName()); 400 this.separator = ObjectProvenance.checkAndExtractProvenance(map, SEP_PROV, CharProvenance.class, CSVLoaderProvenance.class.getSimpleName()); 401 this.quote = ObjectProvenance.checkAndExtractProvenance(map, QUOTE_PROV, CharProvenance.class, CSVLoaderProvenance.class.getSimpleName()); 402 this.path = ObjectProvenance.checkAndExtractProvenance(map, PATH, URLProvenance.class, CSVLoaderProvenance.class.getSimpleName()); 403 this.fileModifiedTime = ObjectProvenance.checkAndExtractProvenance(map, FILE_MODIFIED_TIME, DateTimeProvenance.class, CSVLoaderProvenance.class.getSimpleName()); 404 this.sha256Hash = ObjectProvenance.checkAndExtractProvenance(map, RESOURCE_HASH, HashProvenance.class, CSVLoaderProvenance.class.getSimpleName()); 405 } 406 407 @Override 408 public String getClassName() { 409 return className.getValue(); 410 } 411 412 @Override 413 public Iterator<Pair<String, Provenance>> iterator() { 414 ArrayList<Pair<String, Provenance>> list = new ArrayList<>(); 415 416 list.add(new Pair<>(CLASS_NAME, className)); 417 list.add(new Pair<>(OUTPUT_FACTORY, factoryProvenance)); 418 list.add(new Pair<>(RESPONSE_NAME, responseName)); 419 list.add(new Pair<>(SEP_PROV, separator)); 420 list.add(new Pair<>(QUOTE_PROV, quote)); 421 list.add(new Pair<>(PATH, path)); 422 list.add(new Pair<>(FILE_MODIFIED_TIME, fileModifiedTime)); 423 list.add(new Pair<>(RESOURCE_HASH, sha256Hash)); 424 425 return list.iterator(); 426 } 427 428 @Override 429 public boolean equals(Object o) { 430 if (this == o) return true; 431 if (!(o instanceof CSVLoaderProvenance)) return false; 432 CSVLoaderProvenance pairs = (CSVLoaderProvenance) o; 433 return className.equals(pairs.className) && 434 factoryProvenance.equals(pairs.factoryProvenance) && 435 responseName.equals(pairs.responseName) && 436 separator.equals(pairs.separator) && 437 quote.equals(pairs.quote) && 438 path.equals(pairs.path) && 439 fileModifiedTime.equals(pairs.fileModifiedTime) && 440 sha256Hash.equals(pairs.sha256Hash); 441 } 442 443 @Override 444 public int hashCode() { 445 return Objects.hash(className, factoryProvenance, responseName, separator, quote, path, fileModifiedTime, sha256Hash); 446 } 447 448 @Override 449 public String toString() { 450 return generateString("CSV"); 451 } 452 } 453 454}