/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.data;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.ImmutableDataset;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Trainer;
import org.tribuo.data.columnar.RowProcessor;
import org.tribuo.data.csv.CSVDataSource;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.data.text.TextDataSource;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.data.text.impl.SimpleTextDataSource;
import org.tribuo.data.text.impl.TextFeatureExtractorImpl;
import org.tribuo.data.text.impl.TokenPipeline;
import org.tribuo.dataset.MinimumCardinalityDataset;
import org.tribuo.datasource.LibSVMDataSource;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Path;
import java.util.Locale;
import java.util.logging.Logger;

/**
052 * Options for working with training and test data in a CLI. 053 */ 054public final class DataOptions implements Options { 055 private static final Logger logger = Logger.getLogger(DataOptions.class.getName()); 056 057 /** 058 * The input formats supported by this options object. 059 */ 060 public enum InputFormat { 061 SERIALIZED, LIBSVM, TEXT, CSV, COLUMNAR 062 } 063 064 /** 065 * The delimiters supported by CSV files in this options object. 066 */ 067 public enum Delimiter { 068 COMMA(','), TAB('\t'), SEMICOLON(';'); 069 070 public final char value; 071 Delimiter(char value) { 072 this.value = value; 073 } 074 } 075 076 @Override 077 public String getOptionsDescription() { 078 return "Options for loading and processing train and test data."; 079 } 080 081 @Option(longName="hashing-dimension",usage="Hashing dimension used for standard text format.") 082 public int hashDim = 0; 083 @Option(longName="ngram",usage="Ngram size to generate when using standard text format.") 084 public int ngram = 2; 085 @Option(longName="term-counting",usage="Use term counts instead of boolean when using the standard text format.") 086 public boolean termCounting; 087 @Option(charName='f',longName="model-output-path",usage="Path to serialize model to.") 088 public Path outputPath; 089 @Option(charName='r',longName="seed",usage="RNG seed.") 090 public long seed = Trainer.DEFAULT_SEED; 091 @Option(charName='s',longName="input-format",usage="Loads the data using the specified format.") 092 public InputFormat inputFormat = InputFormat.LIBSVM; 093 @Option(longName="csv-response-name",usage="Response name in the csv file.") 094 public String csvResponseName; 095 @Option(longName="csv-delimiter",usage="Delimiter") 096 public Delimiter delimiter = Delimiter.COMMA; 097 @Option(longName="csv-quote-char",usage="Quote character in the CSV file.") 098 public char csvQuoteChar = '"'; 099 @Option(longName="columnar-row-processor",usage="The name of the row processor from the config file.") 100 
public RowProcessor<?> rowProcessor; 101 @Option(longName="min-count",usage="Minimum cardinality of the features.") 102 public int minCount = 0; 103 @Option(charName='u',longName="training-file",usage="Path to the training file.") 104 public Path trainingPath; 105 @Option(charName='v',longName="testing-file",usage="Path to the testing file.") 106 public Path testingPath; 107 108 public <T extends Output<T>> Pair<Dataset<T>,Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException { 109 logger.info(String.format("Loading data from %s", trainingPath)); 110 Dataset<T> train; 111 Dataset<T> test; 112 char separator; 113 switch (inputFormat) { 114 case SERIALIZED: 115 // 116 // Load Tribuo serialised datasets. 117 logger.info("Deserialising dataset from " + trainingPath); 118 try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(trainingPath.toFile()))); 119 ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(testingPath.toFile())))) { 120 @SuppressWarnings("unchecked") 121 Dataset<T> tmp = (Dataset<T>) ois.readObject(); 122 train = tmp; 123 if (minCount > 0) { 124 logger.info("Found " + train.getFeatureIDMap().size() + " features"); 125 logger.info("Removing features that occur fewer than " + minCount + " times."); 126 train = new MinimumCardinalityDataset<>(train,minCount); 127 } 128 logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString())); 129 logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions"); 130 @SuppressWarnings("unchecked") 131 Dataset<T> deserTest = (Dataset<T>) oits.readObject(); 132 test = new ImmutableDataset<>(deserTest,deserTest.getSourceProvenance(),deserTest.getOutputFactory(),train.getFeatureIDMap(),train.getOutputIDInfo(),true); 133 } catch (ClassNotFoundException e) { 134 throw new IllegalArgumentException("Unknown class in 
serialised files", e); 135 } 136 break; 137 case LIBSVM: 138 // 139 // Load the libsvm text-based data format. 140 LibSVMDataSource<T> trainSVMSource = new LibSVMDataSource<>(trainingPath,outputFactory); 141 train = new MutableDataset<>(trainSVMSource); 142 boolean zeroIndexed = trainSVMSource.isZeroIndexed(); 143 int maxFeatureID = trainSVMSource.getMaxFeatureID(); 144 if (minCount > 0) { 145 logger.info("Removing features that occur fewer than " + minCount + " times."); 146 train = new MinimumCardinalityDataset<>(train,minCount); 147 } 148 logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString())); 149 logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions"); 150 test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath,outputFactory,zeroIndexed,maxFeatureID), train.getFeatureIDMap(), train.getOutputIDInfo(), false); 151 break; 152 case TEXT: 153 // 154 // Using a simple Java break iterator to generate ngram features. 
155 TextFeatureExtractor<T> extractor; 156 if (hashDim > 0) { 157 extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting, hashDim)); 158 } else { 159 extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting)); 160 } 161 162 TextDataSource<T> trainSource = new SimpleTextDataSource<>(trainingPath, outputFactory, extractor); 163 train = new MutableDataset<>(trainSource); 164 if (minCount > 0) { 165 logger.info("Removing features that occur fewer than " + minCount + " times."); 166 train = new MinimumCardinalityDataset<>(train,minCount); 167 } 168 169 logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString())); 170 logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions"); 171 172 TextDataSource<T> testSource = new SimpleTextDataSource<>(testingPath, outputFactory, extractor); 173 test = new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false); 174 break; 175 case CSV: 176 // 177 // Load the data using the simple CSV loader 178 if (csvResponseName == null) { 179 throw new IllegalArgumentException("Please supply a response column name"); 180 } 181 separator = delimiter.value; 182 CSVLoader<T> loader = new CSVLoader<>(separator,outputFactory); 183 train = new MutableDataset<>(loader.loadDataSource(trainingPath,csvResponseName)); 184 logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString())); 185 logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions"); 186 test = new MutableDataset<>(loader.loadDataSource(testingPath,csvResponseName)); 187 break; 188 case COLUMNAR: 189 if (rowProcessor == null) { 190 throw new IllegalArgumentException("Please supply a RowProcessor"); 
191 } 192 OutputFactory<?> rowOutputFactory = rowProcessor.getResponseProcessor().getOutputFactory(); 193 if (!rowOutputFactory.equals(outputFactory)) { 194 throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + rowOutputFactory.getClass().getSimpleName() + ", supplied " + outputFactory.getClass().getName()); 195 } 196 @SuppressWarnings("unchecked") // checked by the if statement above 197 RowProcessor<T> typedRowProcessor = (RowProcessor<T>) rowProcessor; 198 separator = delimiter.value; 199 train = new MutableDataset<>(new CSVDataSource<>(trainingPath,typedRowProcessor,true,separator,csvQuoteChar)); 200 logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString())); 201 logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions"); 202 test = new MutableDataset<>(new CSVDataSource<>(testingPath,typedRowProcessor,true,separator,csvQuoteChar)); 203 break; 204 default: 205 throw new IllegalArgumentException("Unsupported input format " + inputFormat); 206 } 207 logger.info(String.format("Loaded %d testing examples", test.size())); 208 return new Pair<>(train,test); 209 } 210 211 public <T extends Output<T>> void saveModel(Model<T> model) throws IOException { 212 FileOutputStream fout = new FileOutputStream(outputPath.toFile()); 213 ObjectOutputStream oout = new ObjectOutputStream(fout); 214 oout.writeObject(model); 215 oout.close(); 216 fout.close(); 217 logger.info("Serialized model to file: " + outputPath); 218 } 219}