/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.experiments;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.WeightedLabels;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.data.DataOptions;
import org.tribuo.util.Util;

import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
050import java.util.stream.Collectors; 051 052/** 053 * Build and run a classifier for a standard dataset. 054 */ 055public class ConfigurableTrainTest { 056 057 private static final Logger logger = Logger.getLogger(ConfigurableTrainTest.class.getName()); 058 059 public static class ConfigurableTrainTestOptions implements Options { 060 @Override 061 public String getOptionsDescription() { 062 return "Loads a Trainer (and optionally a Datasource) from a config file, trains a Model, tests it and optionally saves it to disk."; 063 } 064 065 public DataOptions general; 066 067 @Option(charName='t',longName="trainer",usage="Load a trainer from the config file.") 068 public Trainer<Label> trainer; 069 070 @Option(charName='w',longName="weights",usage="A list of weights to use in classification. Format = LABEL_NAME:weight,LABEL_NAME:weight...") 071 public List<String> weights; 072 073 @Option(charName='o',longName="predictions",usage="Path to write model predictions") 074 public Path predictionPath; 075 } 076 077 /** 078 * Converts the weight text input format into an object suitable for use in a Trainer. 079 * @param input The input form. 080 * @return The weights. 081 */ 082 public static Map<Label,Float> processWeights(List<String> input) { 083 Map<Label,Float> map = new HashMap<>(); 084 085 for (String tuple : input) { 086 String[] splitTuple = tuple.split(":"); 087 map.put(new Label(splitTuple[0]),Float.parseFloat(splitTuple[1])); 088 } 089 090 return map; 091 } 092 093 /** 094 * @param args the command line arguments 095 */ 096 public static void main(String[] args) { 097 098 // 099 // Use the labs format logging. 
100 LabsLogFormatter.setAllLogFormatters(); 101 102 ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions(); 103 ConfigurationManager cm; 104 try { 105 cm = new ConfigurationManager(args,o); 106 } catch (UsageException e) { 107 logger.info(e.getMessage()); 108 return; 109 } 110 111 if (o.general.trainingPath == null || o.general.testingPath == null) { 112 logger.info(cm.usage()); 113 System.exit(1); 114 } 115 Pair<Dataset<Label>,Dataset<Label>> data = null; 116 try { 117 data = o.general.load(new LabelFactory()); 118 } catch (IOException e) { 119 logger.log(Level.SEVERE, "Failed to load data", e); 120 System.exit(1); 121 } 122 Dataset<Label> train = data.getA(); 123 Dataset<Label> test = data.getB(); 124 125 if (o.trainer == null) { 126 logger.warning("No trainer supplied"); 127 logger.info(cm.usage()); 128 System.exit(1); 129 } 130 logger.info("Trainer is " + o.trainer.toString()); 131 132 if (o.weights != null) { 133 Map<Label,Float> weightsMap = processWeights(o.weights); 134 if (o.trainer instanceof WeightedLabels) { 135 ((WeightedLabels) o.trainer).setLabelWeights(weightsMap); 136 logger.info("Setting label weights using " + weightsMap.toString()); 137 } else if (o.trainer instanceof WeightedExamples) { 138 ((MutableDataset<Label>)train).setWeights(weightsMap); 139 logger.info("Setting example weights using " + weightsMap.toString()); 140 } else { 141 logger.warning("The selected trainer does not support weighted training. 
The chosen trainer is " + o.trainer.toString()); 142 logger.info(cm.usage()); 143 System.exit(1); 144 } 145 } 146 147 logger.info("Labels are " + train.getOutputInfo().toReadableString()); 148 149 final long trainStart = System.currentTimeMillis(); 150 Model<Label> model = o.trainer.train(train); 151 final long trainStop = System.currentTimeMillis(); 152 153 logger.info("Finished training classifier " + Util.formatDuration(trainStart,trainStop)); 154 155 LabelEvaluator labelEvaluator = new LabelEvaluator(); 156 final long testStart = System.currentTimeMillis(); 157 List<Prediction<Label>> predictions = model.predict(test); 158 LabelEvaluation labelEvaluation = labelEvaluator.evaluate(model,predictions,test.getProvenance()); 159 final long testStop = System.currentTimeMillis(); 160 logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop)); 161 System.out.println(labelEvaluation.toString()); 162 ConfusionMatrix<Label> matrix = labelEvaluation.getConfusionMatrix(); 163 System.out.println(matrix.toString()); 164 if (model.generatesProbabilities()) { 165 System.out.println("Average AUC = " + labelEvaluation.averageAUCROC(false)); 166 System.out.println("Average weighted AUC = " + labelEvaluation.averageAUCROC(true)); 167 } 168 169 if(o.predictionPath!=null) { 170 try(BufferedWriter wrt = Files.newBufferedWriter(o.predictionPath)) { 171 List<String> labels = model.getOutputIDInfo().getDomain().stream().map(Label::getLabel).sorted().collect(Collectors.toList()); 172 wrt.write("Label,"); 173 wrt.write(String.join(",", labels)); 174 wrt.newLine(); 175 for(Prediction<Label> pred : predictions) { 176 Example<Label> ex = pred.getExample(); 177 wrt.write(ex.getOutput().getLabel()+","); 178 wrt.write(labels 179 .stream() 180 .map(l -> Double.toString(pred 181 .getOutputScores() 182 .get(l).getScore())) 183 .collect(Collectors.joining(","))); 184 wrt.newLine(); 185 } 186 wrt.flush(); 187 } catch (IOException e) { 188 logger.log(Level.SEVERE, "Error 
writing predictions", e); 189 } 190 } 191 192 if (o.general.outputPath != null) { 193 try { 194 o.general.saveModel(model); 195 } catch (IOException e) { 196 logger.log(Level.SEVERE, "Error writing model", e); 197 } 198 } 199 } 200}