001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.experiments;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import com.oracle.labs.mlrg.olcut.util.Pair;
025import org.tribuo.Dataset;
026import org.tribuo.Example;
027import org.tribuo.Model;
028import org.tribuo.MutableDataset;
029import org.tribuo.Prediction;
030import org.tribuo.Trainer;
031import org.tribuo.WeightedExamples;
032import org.tribuo.classification.Label;
033import org.tribuo.classification.LabelFactory;
034import org.tribuo.classification.WeightedLabels;
035import org.tribuo.classification.evaluation.ConfusionMatrix;
036import org.tribuo.classification.evaluation.LabelEvaluation;
037import org.tribuo.classification.evaluation.LabelEvaluator;
038import org.tribuo.data.DataOptions;
039import org.tribuo.util.Util;
040
041import java.io.BufferedWriter;
042import java.io.IOException;
043import java.nio.file.Files;
044import java.nio.file.Path;
045import java.util.HashMap;
046import java.util.List;
047import java.util.Map;
048import java.util.logging.Level;
049import java.util.logging.Logger;
050import java.util.stream.Collectors;
051
052/**
053 * Build and run a classifier for a standard dataset.
054 */
055public class ConfigurableTrainTest {
056
057    private static final Logger logger = Logger.getLogger(ConfigurableTrainTest.class.getName());
058
059    public static class ConfigurableTrainTestOptions implements Options {
060        @Override
061        public String getOptionsDescription() {
062            return "Loads a Trainer (and optionally a Datasource) from a config file, trains a Model, tests it and optionally saves it to disk.";
063        }
064
065        public DataOptions general;
066
067        @Option(charName='t',longName="trainer",usage="Load a trainer from the config file.")
068        public Trainer<Label> trainer;
069
070        @Option(charName='w',longName="weights",usage="A list of weights to use in classification. Format = LABEL_NAME:weight,LABEL_NAME:weight...")
071        public List<String> weights;
072
073        @Option(charName='o',longName="predictions",usage="Path to write model predictions")
074        public Path predictionPath;
075    }
076
077    /**
078     * Converts the weight text input format into an object suitable for use in a Trainer.
079     * @param input The input form.
080     * @return The weights.
081     */
082    public static Map<Label,Float> processWeights(List<String> input) {
083        Map<Label,Float> map = new HashMap<>();
084
085        for (String tuple : input) {
086            String[] splitTuple = tuple.split(":");
087            map.put(new Label(splitTuple[0]),Float.parseFloat(splitTuple[1]));
088        }
089
090        return map;
091    }
092
093    /**
094     * @param args the command line arguments
095     */
096    public static void main(String[] args) {
097
098        //
099        // Use the labs format logging.
100        LabsLogFormatter.setAllLogFormatters();
101
102        ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions();
103        ConfigurationManager cm;
104        try {
105            cm = new ConfigurationManager(args,o);
106        } catch (UsageException e) {
107            logger.info(e.getMessage());
108            return;
109        }
110
111        if (o.general.trainingPath == null || o.general.testingPath == null) {
112            logger.info(cm.usage());
113            System.exit(1);
114        }
115        Pair<Dataset<Label>,Dataset<Label>> data = null;
116        try {
117             data = o.general.load(new LabelFactory());
118        } catch (IOException e) {
119            logger.log(Level.SEVERE, "Failed to load data", e);
120            System.exit(1);
121        }
122        Dataset<Label> train = data.getA();
123        Dataset<Label> test = data.getB();
124
125        if (o.trainer == null) {
126            logger.warning("No trainer supplied");
127            logger.info(cm.usage());
128            System.exit(1);
129        }
130        logger.info("Trainer is " + o.trainer.toString());
131
132        if (o.weights != null) {
133            Map<Label,Float> weightsMap = processWeights(o.weights);
134            if (o.trainer instanceof WeightedLabels) {
135                ((WeightedLabels) o.trainer).setLabelWeights(weightsMap);
136                logger.info("Setting label weights using " + weightsMap.toString());
137            } else if (o.trainer instanceof WeightedExamples) {
138                ((MutableDataset<Label>)train).setWeights(weightsMap);
139                logger.info("Setting example weights using " + weightsMap.toString());
140            } else {
141                logger.warning("The selected trainer does not support weighted training. The chosen trainer is " + o.trainer.toString());
142                logger.info(cm.usage());
143                System.exit(1);
144            }
145        }
146
147        logger.info("Labels are " + train.getOutputInfo().toReadableString());
148
149        final long trainStart = System.currentTimeMillis();
150        Model<Label> model = o.trainer.train(train);
151        final long trainStop = System.currentTimeMillis();
152                
153        logger.info("Finished training classifier " + Util.formatDuration(trainStart,trainStop));
154
155        LabelEvaluator labelEvaluator = new LabelEvaluator();
156        final long testStart = System.currentTimeMillis();
157        List<Prediction<Label>> predictions = model.predict(test);
158        LabelEvaluation labelEvaluation = labelEvaluator.evaluate(model,predictions,test.getProvenance());
159        final long testStop = System.currentTimeMillis();
160        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
161        System.out.println(labelEvaluation.toString());
162        ConfusionMatrix<Label> matrix = labelEvaluation.getConfusionMatrix();
163        System.out.println(matrix.toString());
164        if (model.generatesProbabilities()) {
165            System.out.println("Average AUC = " + labelEvaluation.averageAUCROC(false));
166            System.out.println("Average weighted AUC = " + labelEvaluation.averageAUCROC(true));
167        }
168
169        if(o.predictionPath!=null) {
170            try(BufferedWriter wrt = Files.newBufferedWriter(o.predictionPath)) {
171                List<String> labels = model.getOutputIDInfo().getDomain().stream().map(Label::getLabel).sorted().collect(Collectors.toList());
172                wrt.write("Label,");
173                wrt.write(String.join(",", labels));
174                wrt.newLine();
175                for(Prediction<Label> pred : predictions) {
176                    Example<Label> ex = pred.getExample();
177                    wrt.write(ex.getOutput().getLabel()+",");
178                    wrt.write(labels
179                            .stream()
180                            .map(l -> Double.toString(pred
181                                    .getOutputScores()
182                                    .get(l).getScore()))
183                            .collect(Collectors.joining(",")));
184                    wrt.newLine();
185                }
186                wrt.flush();
187            } catch (IOException e) {
188                logger.log(Level.SEVERE, "Error writing predictions", e);
189            }
190        }
191
192        if (o.general.outputPath != null) {
193            try {
194                o.general.saveModel(model);
195            } catch (IOException e) {
196                logger.log(Level.SEVERE, "Error writing model", e);
197            }
198        }
199    }
200}