001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.experiments; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.config.UsageException; 023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 024import com.oracle.labs.mlrg.olcut.util.Pair; 025import org.tribuo.Dataset; 026import org.tribuo.Example; 027import org.tribuo.ImmutableDataset; 028import org.tribuo.Model; 029import org.tribuo.Prediction; 030import org.tribuo.classification.Label; 031import org.tribuo.classification.LabelFactory; 032import org.tribuo.classification.evaluation.LabelEvaluation; 033import org.tribuo.classification.evaluation.LabelEvaluator; 034import org.tribuo.data.DataOptions; 035import org.tribuo.data.csv.CSVLoader; 036import org.tribuo.data.text.TextDataSource; 037import org.tribuo.data.text.TextFeatureExtractor; 038import org.tribuo.data.text.impl.SimpleTextDataSource; 039import org.tribuo.data.text.impl.TextFeatureExtractorImpl; 040import org.tribuo.data.text.impl.TokenPipeline; 041import org.tribuo.datasource.LibSVMDataSource; 042import org.tribuo.util.Util; 043import org.tribuo.util.tokens.impl.BreakIteratorTokenizer; 044 045import java.io.BufferedInputStream; 046import java.io.BufferedWriter; 047import java.io.FileInputStream; 048import java.io.IOException; 049import java.io.ObjectInputStream; 050import java.nio.file.Files; 051import java.nio.file.Path; 052import java.util.List; 053import java.util.Locale; 054import java.util.logging.Level; 055import java.util.logging.Logger; 056import java.util.stream.Collectors; 057 058/** 059 * Test a classifier for a standard dataset. 060 */ 061public class Test { 062 063 private static final Logger logger = Logger.getLogger(Test.class.getName()); 064 065 public static class ConfigurableTestOptions implements Options { 066 @Override 067 public String getOptionsDescription() { 068 return "Tests an already trained classifier on a dataset."; 069 } 070 @Option(longName="hashing-dimension",usage="Hashing dimension used for standard text format.") 071 public int hashDim = 0; 072 @Option(longName="ngram",usage="Ngram size to generate when using standard text format. Defaults to 2.") 073 public int ngram = 2; 074 @Option(longName="term-counting",usage="Use term counts instead of boolean when using the standard text format.") 075 public boolean termCounting; 076 @Option(longName="csv-response-name",usage="Response name in the csv file.") 077 public String csvResponseName; 078 @Option(longName="libsvm-zero-indexed",usage="Is the libsvm file zero indexed.") 079 public boolean zeroIndexed = false; 080 @Option(charName='f',longName="model-path",usage="Load a trainer from the config file.") 081 public Path modelPath; 082 @Option(charName='o',longName="predictions",usage="Path to write model predictions") 083 public Path predictionPath; 084 @Option(charName='s',longName="input-format",usage="Loads the data using the specified format. Defaults to LIBSVM.") 085 public DataOptions.InputFormat inputFormat = DataOptions.InputFormat.LIBSVM; 086 @Option(charName='v',longName="testing-file",usage="Path to the testing file.") 087 public Path testingPath; 088 } 089 090 @SuppressWarnings("unchecked") // deserialising generically typed datasets. 091 public static Pair<Model<Label>,Dataset<Label>> load(ConfigurableTestOptions o) throws IOException { 092 Path modelPath = o.modelPath; 093 Path datasetPath = o.testingPath; 094 logger.info(String.format("Loading model from %s", modelPath)); 095 Model<Label> model; 096 try (ObjectInputStream mois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(modelPath.toFile())))) { 097 model = (Model<Label>) mois.readObject(); 098 boolean valid = model.validate(Label.class); 099 if (!valid) { 100 throw new ClassCastException("Failed to cast deserialised Model to Model<Label>"); 101 } 102 } catch (ClassNotFoundException e) { 103 throw new IllegalArgumentException("Unknown class in serialised model", e); 104 } 105 logger.info(String.format("Loading data from %s", datasetPath)); 106 Dataset<Label> test; 107 switch (o.inputFormat) { 108 case SERIALIZED: 109 // 110 // Load Tribuo serialised datasets. 111 logger.info("Deserialising dataset from " + datasetPath); 112 try (ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(datasetPath.toFile())))) { 113 Dataset<Label> deserTest = (Dataset<Label>) oits.readObject(); 114 test = ImmutableDataset.copyDataset(deserTest,model.getFeatureIDMap(),model.getOutputIDInfo()); 115 logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString())); 116 } catch (ClassNotFoundException e) { 117 throw new IllegalArgumentException("Unknown class in serialised dataset", e); 118 } 119 break; 120 case LIBSVM: 121 // 122 // Load the libsvm text-based data format. 123 boolean zeroIndexed = o.zeroIndexed; 124 int maxFeatureID = model.getFeatureIDMap().size() - 1; 125 LibSVMDataSource<Label> testSVMSource = new LibSVMDataSource<>(datasetPath,new LabelFactory(),zeroIndexed,maxFeatureID); 126 test = new ImmutableDataset<>(testSVMSource,model,true); 127 logger.info(String.format("Loaded %d training examples for %s", test.size(), test.getOutputs().toString())); 128 break; 129 case TEXT: 130 // 131 // Using a simple Java break iterator to generate ngram features. 132 TextFeatureExtractor<Label> extractor; 133 if (o.hashDim > 0) { 134 extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), o.ngram, o.termCounting, o.hashDim)); 135 } else { 136 extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), o.ngram, o.termCounting)); 137 } 138 139 TextDataSource<Label> testSource = new SimpleTextDataSource<>(datasetPath, new LabelFactory(), extractor); 140 test = new ImmutableDataset<>(testSource, model.getFeatureIDMap(), model.getOutputIDInfo(),true); 141 logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString())); 142 break; 143 case CSV: 144 // 145 // Load the data using the simple CSV loader 146 if (o.csvResponseName == null) { 147 throw new IllegalArgumentException("Please supply a response column name"); 148 } 149 CSVLoader<Label> loader = new CSVLoader<>(new LabelFactory()); 150 test = new ImmutableDataset<>(loader.loadDataSource(datasetPath,o.csvResponseName),model.getFeatureIDMap(),model.getOutputIDInfo(),true); 151 logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString())); 152 break; 153 default: 154 throw new IllegalArgumentException("Unsupported input format " + o.inputFormat); 155 } 156 return new Pair<>(model,test); 157 } 158 159 /** 160 * @param args the command line arguments 161 */ 162 public static void main(String[] args) { 163 164 // 165 // Use the labs format logging. 166 LabsLogFormatter.setAllLogFormatters(); 167 168 ConfigurableTestOptions o = new ConfigurableTestOptions(); 169 ConfigurationManager cm; 170 try { 171 cm = new ConfigurationManager(args,o); 172 } catch (UsageException e) { 173 logger.info(e.getMessage()); 174 return; 175 } 176 177 if (o.modelPath == null || o.testingPath == null) { 178 logger.info(cm.usage()); 179 System.exit(1); 180 } 181 Pair<Model<Label>,Dataset<Label>> loaded = null; 182 try { 183 loaded = load(o); 184 } catch (IOException e) { 185 logger.log(Level.SEVERE, "Failed to load model/data", e); 186 System.exit(1); 187 } 188 Model<Label> model = loaded.getA(); 189 Dataset<Label> test = loaded.getB(); 190 191 logger.info("Model is " + model.toString()); 192 logger.info("Labels are " + model.getOutputIDInfo().toReadableString()); 193 194 LabelEvaluator labelEvaluator = new LabelEvaluator(); 195 final long testStart = System.currentTimeMillis(); 196 List<Prediction<Label>> predictions = model.predict(test); 197 LabelEvaluation evaluation = labelEvaluator.evaluate(model,predictions,test.getProvenance()); 198 final long testStop = System.currentTimeMillis(); 199 logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop)); 200 System.out.println(evaluation.toString()); 201 System.out.println(evaluation.getConfusionMatrix().toString()); 202 if (model.generatesProbabilities()) { 203 System.out.println("Average AUC = " + evaluation.averageAUCROC(false)); 204 System.out.println("Average weighted AUC = " + evaluation.averageAUCROC(true)); 205 } 206 207 if (o.predictionPath!=null) { 208 try(BufferedWriter wrt = Files.newBufferedWriter(o.predictionPath)) { 209 List<String> labels = model.getOutputIDInfo().getDomain().stream().map(Label::getLabel).sorted().collect(Collectors.toList()); 210 wrt.write("Label,"); 211 wrt.write(String.join(",", labels)); 212 wrt.newLine(); 213 for(Prediction<Label> pred : predictions) { 214 Example<Label> ex = pred.getExample(); 215 wrt.write(ex.getOutput().getLabel()+","); 216 wrt.write(labels 217 .stream() 218 .map(l -> Double.toString(pred 219 .getOutputScores() 220 .get(l).getScore())) 221 .collect(Collectors.joining(","))); 222 wrt.newLine(); 223 } 224 wrt.flush(); 225 } catch (IOException e) { 226 logger.log(Level.SEVERE, "Error writing predictions", e); 227 } 228 } 229 230 } 231 232}