001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.experiments;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import com.oracle.labs.mlrg.olcut.util.Pair;
025import org.tribuo.Dataset;
026import org.tribuo.Example;
027import org.tribuo.ImmutableDataset;
028import org.tribuo.Model;
029import org.tribuo.Prediction;
030import org.tribuo.classification.Label;
031import org.tribuo.classification.LabelFactory;
032import org.tribuo.classification.evaluation.LabelEvaluation;
033import org.tribuo.classification.evaluation.LabelEvaluator;
034import org.tribuo.data.DataOptions;
035import org.tribuo.data.csv.CSVLoader;
036import org.tribuo.data.text.TextDataSource;
037import org.tribuo.data.text.TextFeatureExtractor;
038import org.tribuo.data.text.impl.SimpleTextDataSource;
039import org.tribuo.data.text.impl.TextFeatureExtractorImpl;
040import org.tribuo.data.text.impl.TokenPipeline;
041import org.tribuo.datasource.LibSVMDataSource;
042import org.tribuo.util.Util;
043import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;
044
045import java.io.BufferedInputStream;
046import java.io.BufferedWriter;
047import java.io.FileInputStream;
048import java.io.IOException;
049import java.io.ObjectInputStream;
050import java.nio.file.Files;
051import java.nio.file.Path;
052import java.util.List;
053import java.util.Locale;
054import java.util.logging.Level;
055import java.util.logging.Logger;
056import java.util.stream.Collectors;
057
058/**
059 * Test a classifier for a standard dataset.
060 */
061public class Test {
062
063    private static final Logger logger = Logger.getLogger(Test.class.getName());
064
065    public static class ConfigurableTestOptions implements Options {
066        @Override
067        public String getOptionsDescription() {
068            return "Tests an already trained classifier on a dataset.";
069        }
070        @Option(longName="hashing-dimension",usage="Hashing dimension used for standard text format.")
071        public int hashDim = 0;
072        @Option(longName="ngram",usage="Ngram size to generate when using standard text format. Defaults to 2.")
073        public int ngram = 2;
074        @Option(longName="term-counting",usage="Use term counts instead of boolean when using the standard text format.")
075        public boolean termCounting;
076        @Option(longName="csv-response-name",usage="Response name in the csv file.")
077        public String csvResponseName;
078        @Option(longName="libsvm-zero-indexed",usage="Is the libsvm file zero indexed.")
079        public boolean zeroIndexed = false;
080        @Option(charName='f',longName="model-path",usage="Load a trainer from the config file.")
081        public Path modelPath;
082        @Option(charName='o',longName="predictions",usage="Path to write model predictions")
083        public Path predictionPath;
084        @Option(charName='s',longName="input-format",usage="Loads the data using the specified format. Defaults to LIBSVM.")
085        public DataOptions.InputFormat inputFormat = DataOptions.InputFormat.LIBSVM;
086        @Option(charName='v',longName="testing-file",usage="Path to the testing file.")
087        public Path testingPath;
088    }
089
090    @SuppressWarnings("unchecked") // deserialising generically typed datasets.
091    public static Pair<Model<Label>,Dataset<Label>> load(ConfigurableTestOptions o) throws IOException {
092        Path modelPath = o.modelPath;
093        Path datasetPath = o.testingPath;
094        logger.info(String.format("Loading model from %s", modelPath));
095        Model<Label> model;
096        try (ObjectInputStream mois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(modelPath.toFile())))) {
097            model = (Model<Label>) mois.readObject();
098            boolean valid = model.validate(Label.class);
099            if (!valid) {
100                throw new ClassCastException("Failed to cast deserialised Model to Model<Label>");
101            }
102        } catch (ClassNotFoundException e) {
103            throw new IllegalArgumentException("Unknown class in serialised model", e);
104        }
105        logger.info(String.format("Loading data from %s", datasetPath));
106        Dataset<Label> test;
107        switch (o.inputFormat) {
108            case SERIALIZED:
109                //
110                // Load Tribuo serialised datasets.
111                logger.info("Deserialising dataset from " + datasetPath);
112                try (ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(datasetPath.toFile())))) {
113                    Dataset<Label> deserTest = (Dataset<Label>) oits.readObject();
114                    test = ImmutableDataset.copyDataset(deserTest,model.getFeatureIDMap(),model.getOutputIDInfo());
115                    logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
116                } catch (ClassNotFoundException e) {
117                    throw new IllegalArgumentException("Unknown class in serialised dataset", e);
118                }
119                break;
120            case LIBSVM:
121                //
122                // Load the libsvm text-based data format.
123                boolean zeroIndexed = o.zeroIndexed;
124                int maxFeatureID = model.getFeatureIDMap().size() - 1;
125                LibSVMDataSource<Label> testSVMSource = new LibSVMDataSource<>(datasetPath,new LabelFactory(),zeroIndexed,maxFeatureID);
126                test = new ImmutableDataset<>(testSVMSource,model,true);
127                logger.info(String.format("Loaded %d training examples for %s", test.size(), test.getOutputs().toString()));
128                break;
129            case TEXT:
130                //
131                // Using a simple Java break iterator to generate ngram features.
132                TextFeatureExtractor<Label> extractor;
133                if (o.hashDim > 0) {
134                    extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), o.ngram, o.termCounting, o.hashDim));
135                } else {
136                    extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), o.ngram, o.termCounting));
137                }
138
139                TextDataSource<Label> testSource = new SimpleTextDataSource<>(datasetPath, new LabelFactory(), extractor);
140                test = new ImmutableDataset<>(testSource, model.getFeatureIDMap(), model.getOutputIDInfo(),true);
141                logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
142                break;
143            case CSV:
144                //
145                // Load the data using the simple CSV loader
146                if (o.csvResponseName == null) {
147                    throw new IllegalArgumentException("Please supply a response column name");
148                }
149                CSVLoader<Label> loader = new CSVLoader<>(new LabelFactory());
150                test = new ImmutableDataset<>(loader.loadDataSource(datasetPath,o.csvResponseName),model.getFeatureIDMap(),model.getOutputIDInfo(),true);
151                logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
152                break;
153            default:
154                throw new IllegalArgumentException("Unsupported input format " + o.inputFormat);
155        }
156        return new Pair<>(model,test);
157    }
158
159    /**
160     * @param args the command line arguments
161     */
162    public static void main(String[] args) {
163
164        //
165        // Use the labs format logging.
166        LabsLogFormatter.setAllLogFormatters();
167
168        ConfigurableTestOptions o = new ConfigurableTestOptions();
169        ConfigurationManager cm;
170        try {
171            cm = new ConfigurationManager(args,o);
172        } catch (UsageException e) {
173            logger.info(e.getMessage());
174            return;
175        }
176
177        if (o.modelPath == null || o.testingPath == null) {
178            logger.info(cm.usage());
179            System.exit(1);
180        }
181        Pair<Model<Label>,Dataset<Label>> loaded = null;
182        try {
183             loaded = load(o);
184        } catch (IOException e) {
185            logger.log(Level.SEVERE, "Failed to load model/data", e);
186            System.exit(1);
187        }
188        Model<Label> model = loaded.getA();
189        Dataset<Label> test = loaded.getB();
190
191        logger.info("Model is " + model.toString());
192        logger.info("Labels are " + model.getOutputIDInfo().toReadableString());
193
194        LabelEvaluator labelEvaluator = new LabelEvaluator();
195        final long testStart = System.currentTimeMillis();
196        List<Prediction<Label>> predictions = model.predict(test);
197        LabelEvaluation evaluation = labelEvaluator.evaluate(model,predictions,test.getProvenance());
198        final long testStop = System.currentTimeMillis();
199        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
200        System.out.println(evaluation.toString());
201        System.out.println(evaluation.getConfusionMatrix().toString());
202        if (model.generatesProbabilities()) {
203            System.out.println("Average AUC = " + evaluation.averageAUCROC(false));
204            System.out.println("Average weighted AUC = " + evaluation.averageAUCROC(true));
205        }
206
207        if (o.predictionPath!=null) {
208            try(BufferedWriter wrt = Files.newBufferedWriter(o.predictionPath)) {
209                List<String> labels = model.getOutputIDInfo().getDomain().stream().map(Label::getLabel).sorted().collect(Collectors.toList());
210                wrt.write("Label,");
211                wrt.write(String.join(",", labels));
212                wrt.newLine();
213                for(Prediction<Label> pred : predictions) {
214                    Example<Label> ex = pred.getExample();
215                    wrt.write(ex.getOutput().getLabel()+",");
216                    wrt.write(labels
217                            .stream()
218                            .map(l -> Double.toString(pred
219                                    .getOutputScores()
220                                    .get(l).getScore()))
221                            .collect(Collectors.joining(",")));
222                    wrt.newLine();
223                }
224                wrt.flush();
225            } catch (IOException e) {
226                logger.log(Level.SEVERE, "Error writing predictions", e);
227            }
228        }
229
230    }
231    
232}