/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.data;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.ImmutableDataset;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Trainer;
import org.tribuo.data.columnar.RowProcessor;
import org.tribuo.data.csv.CSVDataSource;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.data.text.TextDataSource;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.data.text.impl.SimpleTextDataSource;
import org.tribuo.data.text.impl.TextFeatureExtractorImpl;
import org.tribuo.data.text.impl.TokenPipeline;
import org.tribuo.dataset.MinimumCardinalityDataset;
import org.tribuo.datasource.LibSVMDataSource;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Path;
import java.util.Locale;
import java.util.logging.Logger;

/**
 * Options for working with training and test data in a CLI.
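 * <p>
 * A minimal usage sketch (assuming OLCUT's {@code ConfigurationManager} parses the command
 * line, and that {@code Label} and {@code LabelFactory} from the classification module are
 * on the classpath):
 * <pre>{@code
 * DataOptions opts = new DataOptions();
 * // Populates the option fields from the command-line arguments.
 * ConfigurationManager cm = new ConfigurationManager(args, opts);
 * Pair<Dataset<Label>, Dataset<Label>> data = opts.load(new LabelFactory());
 * }</pre>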
 */
public final class DataOptions implements Options {
    private static final Logger logger = Logger.getLogger(DataOptions.class.getName());

    /**
     * The input formats supported by this options object.
     */
    public enum InputFormat {
        SERIALIZED, LIBSVM, TEXT, CSV, COLUMNAR
    }

    /**
     * The delimiters supported by CSV files in this options object.
     */
    public enum Delimiter {
        COMMA(','), TAB('\t'), SEMICOLON(';');

        public final char value;
        Delimiter(char value) {
            this.value = value;
        }
    }

    @Override
    public String getOptionsDescription() {
        return "Options for loading and processing train and test data.";
    }

    @Option(longName="hashing-dimension",usage="Hashing dimension used for standard text format.")
    public int hashDim = 0;
    @Option(longName="ngram",usage="Ngram size to generate when using standard text format.")
    public int ngram = 2;
    @Option(longName="term-counting",usage="Use term counts instead of boolean when using the standard text format.")
    public boolean termCounting;
    @Option(charName='f',longName="model-output-path",usage="Path to serialize model to.")
    public Path outputPath;
    @Option(charName='r',longName="seed",usage="RNG seed.")
    public long seed = Trainer.DEFAULT_SEED;
    @Option(charName='s',longName="input-format",usage="Loads the data using the specified format.")
    public InputFormat inputFormat = InputFormat.LIBSVM;
    @Option(longName="csv-response-name",usage="Response name in the csv file.")
    public String csvResponseName;
    @Option(longName="csv-delimiter",usage="Delimiter character in the CSV file.")
    public Delimiter delimiter = Delimiter.COMMA;
    @Option(longName="csv-quote-char",usage="Quote character in the CSV file.")
    public char csvQuoteChar = '"';
    @Option(longName="columnar-row-processor",usage="The name of the row processor from the config file.")
    public RowProcessor<?> rowProcessor;
    @Option(longName="min-count",usage="Minimum cardinality of the features.")
    public int minCount = 0;
    @Option(charName='u',longName="training-file",usage="Path to the training file.")
    public Path trainingPath;
    @Option(charName='v',longName="testing-file",usage="Path to the testing file.")
    public Path testingPath;

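    /**
     * Loads the training and testing data according to the configured {@link InputFormat}.
     * For the SERIALIZED, LIBSVM and TEXT formats, features which occur fewer than
     * {@link #minCount} times in the training set are removed when {@code minCount} is positive.
     * <p>
     * A minimal calling sketch (assuming {@code Label} and {@code LabelFactory} from the
     * classification module, with hypothetical file paths):
     * <pre>{@code
     * DataOptions opts = new DataOptions();
     * opts.inputFormat = DataOptions.InputFormat.LIBSVM;
     * opts.trainingPath = Paths.get("train.libsvm");
     * opts.testingPath = Paths.get("test.libsvm");
     * Pair<Dataset<Label>, Dataset<Label>> data = opts.load(new LabelFactory());
     * }</pre>
     * @param outputFactory The output factory used to generate the outputs.
     * @param <T> The output type.
     * @return A pair containing the training and testing datasets.
     * @throws IOException If the data could not be loaded.
     */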
    public <T extends Output<T>> Pair<Dataset<T>,Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException {
        logger.info(String.format("Loading data from %s", trainingPath));
        Dataset<T> train;
        Dataset<T> test;
        char separator;
        switch (inputFormat) {
            case SERIALIZED:
                //
                // Load Tribuo serialised datasets.
                logger.info("Deserialising dataset from " + trainingPath);
                try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(trainingPath.toFile())));
                     ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(testingPath.toFile())))) {
                    @SuppressWarnings("unchecked")
                    Dataset<T> tmp = (Dataset<T>) ois.readObject();
                    train = tmp;
                    if (minCount > 0) {
                        logger.info("Found " + train.getFeatureIDMap().size() + " features");
                        logger.info("Removing features that occur fewer than " + minCount + " times.");
                        train = new MinimumCardinalityDataset<>(train,minCount);
                    }
                    logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                    logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
                    @SuppressWarnings("unchecked")
                    Dataset<T> deserTest = (Dataset<T>) oits.readObject();
                    test = new ImmutableDataset<>(deserTest,deserTest.getSourceProvenance(),deserTest.getOutputFactory(),train.getFeatureIDMap(),train.getOutputIDInfo(),true);
                } catch (ClassNotFoundException e) {
                    throw new IllegalArgumentException("Unknown class in serialised files", e);
                }
                break;
            case LIBSVM:
                //
                // Load the libsvm text-based data format.
                LibSVMDataSource<T> trainSVMSource = new LibSVMDataSource<>(trainingPath,outputFactory);
                train = new MutableDataset<>(trainSVMSource);
                boolean zeroIndexed = trainSVMSource.isZeroIndexed();
                int maxFeatureID = trainSVMSource.getMaxFeatureID();
                if (minCount > 0) {
                    logger.info("Removing features that occur fewer than " + minCount + " times.");
                    train = new MinimumCardinalityDataset<>(train,minCount);
                }
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
                test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath,outputFactory,zeroIndexed,maxFeatureID), train.getFeatureIDMap(), train.getOutputIDInfo(), false);
                break;
            case TEXT:
                //
                // Using a simple Java break iterator to generate ngram features.
                TextFeatureExtractor<T> extractor;
                if (hashDim > 0) {
                    extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting, hashDim));
                } else {
                    extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting));
                }

                TextDataSource<T> trainSource = new SimpleTextDataSource<>(trainingPath, outputFactory, extractor);
                train = new MutableDataset<>(trainSource);
                if (minCount > 0) {
                    logger.info("Removing features that occur fewer than " + minCount + " times.");
                    train = new MinimumCardinalityDataset<>(train,minCount);
                }

                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");

                TextDataSource<T> testSource = new SimpleTextDataSource<>(testingPath, outputFactory, extractor);
                test = new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);
                break;
            case CSV:
                //
                // Load the data using the simple CSV loader
                if (csvResponseName == null) {
                    throw new IllegalArgumentException("Please supply a response column name");
                }
                separator = delimiter.value;
                CSVLoader<T> loader = new CSVLoader<>(separator,outputFactory);
                train = new MutableDataset<>(loader.loadDataSource(trainingPath,csvResponseName));
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
                test = new MutableDataset<>(loader.loadDataSource(testingPath,csvResponseName));
                break;
            case COLUMNAR:
                if (rowProcessor == null) {
                    throw new IllegalArgumentException("Please supply a RowProcessor");
                }
                OutputFactory<?> rowOutputFactory = rowProcessor.getResponseProcessor().getOutputFactory();
                if (!rowOutputFactory.equals(outputFactory)) {
                    throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + rowOutputFactory.getClass().getName() + ", supplied " + outputFactory.getClass().getName());
                }
                @SuppressWarnings("unchecked") // checked by the if statement above
                RowProcessor<T> typedRowProcessor = (RowProcessor<T>) rowProcessor;
                separator = delimiter.value;
                train = new MutableDataset<>(new CSVDataSource<>(trainingPath,typedRowProcessor,true,separator,csvQuoteChar));
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
                test = new MutableDataset<>(new CSVDataSource<>(testingPath,typedRowProcessor,true,separator,csvQuoteChar));
                break;
            default:
                throw new IllegalArgumentException("Unsupported input format " + inputFormat);
        }
        logger.info(String.format("Loaded %d testing examples", test.size()));
        return new Pair<>(train,test);
    }

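    /**
     * Serializes the supplied model to {@link #outputPath} using Java serialization.
     * @param model The model to save.
     * @param <T> The output type of the model.
     * @throws IOException If the model could not be written to disk.
     */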
    public <T extends Output<T>> void saveModel(Model<T> model) throws IOException {
        // try-with-resources ensures the streams are closed even if writeObject throws
        try (FileOutputStream fout = new FileOutputStream(outputPath.toFile());
             ObjectOutputStream oout = new ObjectOutputStream(fout)) {
            oout.writeObject(model);
        }
        logger.info("Serialized model to file: " + outputPath);
    }
}