001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.data; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.config.UsageException; 023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 024import com.oracle.labs.mlrg.olcut.util.Pair; 025import org.tribuo.Dataset; 026import org.tribuo.Model; 027import org.tribuo.Output; 028import org.tribuo.OutputFactory; 029import org.tribuo.Trainer; 030import org.tribuo.evaluation.CrossValidation; 031import org.tribuo.evaluation.DescriptiveStats; 032import org.tribuo.evaluation.Evaluation; 033import org.tribuo.evaluation.EvaluationAggregator; 034import org.tribuo.evaluation.Evaluator; 035import org.tribuo.evaluation.metrics.MetricID; 036import org.tribuo.transform.TransformTrainer; 037import org.tribuo.transform.TransformationMap; 038import org.tribuo.util.Util; 039 040import java.io.IOException; 041import java.util.ArrayList; 042import java.util.Comparator; 043import java.util.List; 044import java.util.Map; 045import java.util.logging.Level; 046import java.util.logging.Logger; 047import java.util.stream.Collectors; 048 049/** 050 * Build and run a predictor for a standard dataset. 051 */ 052public final class ConfigurableTrainTest { 053 054 private static final Logger logger = Logger.getLogger(ConfigurableTrainTest.class.getName()); 055 056 private ConfigurableTrainTest() {} 057 058 public static class ConfigurableTrainTestOptions implements Options { 059 @Override 060 public String getOptionsDescription() { 061 return "Loads a Trainer from a config file, trains a Model (optionally with cross-validation), tests it and optionally saves it to disk."; 062 } 063 064 public DataOptions general; 065 066 @Option(charName='t',longName="trainer",usage="Load a trainer from the config file.") 067 public Trainer<?> trainer; 068 069 @Option(longName="transformer",usage="Load a transformation map from the config file.") 070 public TransformationMap transformationMap; 071 072 @Option(charName='a',longName="output-factory",usage="The output factory to construct.") 073 public OutputFactory<?> outputFactory; 074 075 @Option(charName='x',longName="cross-validate",usage="Cross-validate the output metrics.") 076 public boolean crossValidation; 077 078 @Option(charName='n',longName="num-folds",usage="The number of cross validation folds.") 079 public int numFolds = 5; 080 } 081 082 /** 083 * @param args the command line arguments 084 * @param <T> The {@link Output} subclass. 085 */ 086 @SuppressWarnings("unchecked") 087 public static <T extends Output<T>> void main(String[] args) { 088 089 // 090 // Use the labs format logging. 091 LabsLogFormatter.setAllLogFormatters(); 092 093 ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions(); 094 ConfigurationManager cm; 095 try { 096 cm = new ConfigurationManager(args,o); 097 } catch (UsageException e) { 098 logger.info(e.getMessage()); 099 return; 100 } 101 102 if (o.general.trainingPath == null || o.general.testingPath == null || o.outputFactory == null) { 103 logger.info(cm.usage()); 104 System.exit(1); 105 } 106 107 Pair<Dataset<T>,Dataset<T>> data = null; 108 try { 109 data = o.general.load((OutputFactory<T>)o.outputFactory); 110 } catch (IOException e) { 111 logger.log(Level.SEVERE, "Failed to load data", e); 112 System.exit(1); 113 } 114 Dataset<T> train = data.getA(); 115 Dataset<T> test = data.getB(); 116 117 if (o.trainer == null) { 118 logger.warning("No trainer supplied"); 119 logger.info(cm.usage()); 120 System.exit(1); 121 } 122 123 if (o.transformationMap != null) { 124 o.trainer = new TransformTrainer<>(o.trainer,o.transformationMap); 125 } 126 logger.info("Trainer is " + o.trainer.getProvenance().toString()); 127 128 logger.info("Outputs are " + train.getOutputInfo().toReadableString()); 129 130 logger.info("Number of features: " + train.getFeatureMap().size()); 131 132 final long trainStart = System.currentTimeMillis(); 133 Model<T> model = ((Trainer<T>)o.trainer).train(train); 134 final long trainStop = System.currentTimeMillis(); 135 136 logger.info("Finished training classifier " + Util.formatDuration(trainStart,trainStop)); 137 138 Evaluator<T,? extends Evaluation<T>> evaluator = train.getOutputFactory().getEvaluator(); 139 final long testStart = System.currentTimeMillis(); 140 Evaluation<T> evaluation = evaluator.evaluate(model,test); 141 final long testStop = System.currentTimeMillis(); 142 logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop)); 143 System.out.println(evaluation.toString()); 144 145 if (o.general.outputPath != null) { 146 try { 147 o.general.saveModel(model); 148 } catch (IOException e) { 149 logger.log(Level.SEVERE, "Error writing model", e); 150 } 151 } 152 153 if (o.crossValidation) { 154 if (o.numFolds > 1) { 155 logger.info("Running " + o.numFolds + " fold cross-validation"); 156 CrossValidation<T,? extends Evaluation<T>> cv = new CrossValidation<>((Trainer<T>)o.trainer,train,evaluator,o.numFolds,o.general.seed); 157 List<? extends Pair<? extends Evaluation<T>, Model<T>>> evaluations = cv.evaluate(); 158 List<Evaluation<T>> evals = evaluations.stream().map(Pair::getA).collect(Collectors.toList()); 159 // Summarize across everything 160 Map<MetricID<T>, DescriptiveStats> summary = EvaluationAggregator.summarize(evals); 161 162 List<MetricID<T>> keys = new ArrayList<>(summary.keySet()) 163 .stream() 164 .sorted(Comparator.comparing(Pair::getB)) 165 .collect(Collectors.toList()); 166 System.out.println("Summary across the folds:"); 167 for (MetricID<T> key : keys) { 168 DescriptiveStats stats = summary.get(key); 169 System.out.printf("%-10s %.5f (%.5f)%n", key, stats.getMean(), stats.getStandardDeviation()); 170 } 171 } else { 172 logger.warning("The number of cross-validation folds must be greater than 1, found " + o.numFolds); 173 } 174 } 175 } 176}