001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import com.oracle.labs.mlrg.olcut.util.Pair;
025import org.tribuo.Dataset;
026import org.tribuo.Model;
027import org.tribuo.Output;
028import org.tribuo.OutputFactory;
029import org.tribuo.Trainer;
030import org.tribuo.evaluation.CrossValidation;
031import org.tribuo.evaluation.DescriptiveStats;
032import org.tribuo.evaluation.Evaluation;
033import org.tribuo.evaluation.EvaluationAggregator;
034import org.tribuo.evaluation.Evaluator;
035import org.tribuo.evaluation.metrics.MetricID;
036import org.tribuo.transform.TransformTrainer;
037import org.tribuo.transform.TransformationMap;
038import org.tribuo.util.Util;
039
040import java.io.IOException;
041import java.util.ArrayList;
042import java.util.Comparator;
043import java.util.List;
044import java.util.Map;
045import java.util.logging.Level;
046import java.util.logging.Logger;
047import java.util.stream.Collectors;
048
049/**
050 * Build and run a predictor for a standard dataset.
051 */
052public final class ConfigurableTrainTest {
053
054    private static final Logger logger = Logger.getLogger(ConfigurableTrainTest.class.getName());
055
056    private ConfigurableTrainTest() {}
057
058    public static class ConfigurableTrainTestOptions implements Options {
059        @Override
060        public String getOptionsDescription() {
061            return "Loads a Trainer from a config file, trains a Model (optionally with cross-validation), tests it and optionally saves it to disk.";
062        }
063
064        public DataOptions general;
065
066        @Option(charName='t',longName="trainer",usage="Load a trainer from the config file.")
067        public Trainer<?> trainer;
068
069        @Option(longName="transformer",usage="Load a transformation map from the config file.")
070        public TransformationMap transformationMap;
071
072        @Option(charName='a',longName="output-factory",usage="The output factory to construct.")
073        public OutputFactory<?> outputFactory;
074
075        @Option(charName='x',longName="cross-validate",usage="Cross-validate the output metrics.")
076        public boolean crossValidation;
077
078        @Option(charName='n',longName="num-folds",usage="The number of cross validation folds.")
079        public int numFolds = 5;
080    }
081
082    /**
083     * @param args the command line arguments
084     * @param <T> The {@link Output} subclass.
085     */
086    @SuppressWarnings("unchecked")
087    public static <T extends Output<T>> void main(String[] args) {
088
089        //
090        // Use the labs format logging.
091        LabsLogFormatter.setAllLogFormatters();
092
093        ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions();
094        ConfigurationManager cm;
095        try {
096            cm = new ConfigurationManager(args,o);
097        } catch (UsageException e) {
098            logger.info(e.getMessage());
099            return;
100        }
101
102        if (o.general.trainingPath == null || o.general.testingPath == null || o.outputFactory == null) {
103            logger.info(cm.usage());
104            System.exit(1);
105        }
106
107        Pair<Dataset<T>,Dataset<T>> data = null;
108        try {
109             data = o.general.load((OutputFactory<T>)o.outputFactory);
110        } catch (IOException e) {
111            logger.log(Level.SEVERE, "Failed to load data", e);
112            System.exit(1);
113        }
114        Dataset<T> train = data.getA();
115        Dataset<T> test = data.getB();
116
117        if (o.trainer == null) {
118            logger.warning("No trainer supplied");
119            logger.info(cm.usage());
120            System.exit(1);
121        }
122
123        if (o.transformationMap != null) {
124            o.trainer = new TransformTrainer<>(o.trainer,o.transformationMap);
125        }
126        logger.info("Trainer is " + o.trainer.getProvenance().toString());
127
128        logger.info("Outputs are " + train.getOutputInfo().toReadableString());
129
130        logger.info("Number of features: " + train.getFeatureMap().size());
131
132        final long trainStart = System.currentTimeMillis();
133        Model<T> model = ((Trainer<T>)o.trainer).train(train);
134        final long trainStop = System.currentTimeMillis();
135                
136        logger.info("Finished training classifier " + Util.formatDuration(trainStart,trainStop));
137
138        Evaluator<T,? extends Evaluation<T>> evaluator = train.getOutputFactory().getEvaluator();
139        final long testStart = System.currentTimeMillis();
140        Evaluation<T> evaluation = evaluator.evaluate(model,test);
141        final long testStop = System.currentTimeMillis();
142        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
143        System.out.println(evaluation.toString());
144
145        if (o.general.outputPath != null) {
146            try {
147                o.general.saveModel(model);
148            } catch (IOException e) {
149                logger.log(Level.SEVERE, "Error writing model", e);
150            }
151        }
152
153        if (o.crossValidation) {
154            if (o.numFolds > 1) {
155                logger.info("Running " + o.numFolds + " fold cross-validation");
156                CrossValidation<T,? extends Evaluation<T>> cv = new CrossValidation<>((Trainer<T>)o.trainer,train,evaluator,o.numFolds,o.general.seed);
157                List<? extends Pair<? extends Evaluation<T>, Model<T>>> evaluations = cv.evaluate();
158                List<Evaluation<T>> evals = evaluations.stream().map(Pair::getA).collect(Collectors.toList());
159                // Summarize across everything
160                Map<MetricID<T>, DescriptiveStats> summary = EvaluationAggregator.summarize(evals);
161
162                List<MetricID<T>> keys = new ArrayList<>(summary.keySet())
163                        .stream()
164                        .sorted(Comparator.comparing(Pair::getB))
165                        .collect(Collectors.toList());
166                System.out.println("Summary across the folds:");
167                for (MetricID<T> key : keys) {
168                    DescriptiveStats stats = summary.get(key);
169                    System.out.printf("%-10s  %.5f (%.5f)%n", key, stats.getMean(), stats.getStandardDeviation());
170                }
171            } else {
172                logger.warning("The number of cross-validation folds must be greater than 1, found " + o.numFolds);
173            }
174        }
175    }
176}