001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.experiments;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import com.oracle.labs.mlrg.olcut.util.Pair;
025import org.tribuo.Dataset;
026import org.tribuo.Model;
027import org.tribuo.Trainer;
028import org.tribuo.classification.Label;
029import org.tribuo.classification.LabelFactory;
030import org.tribuo.classification.evaluation.ConfusionMatrix;
031import org.tribuo.classification.evaluation.LabelEvaluation;
032import org.tribuo.classification.evaluation.LabelEvaluator;
033import org.tribuo.data.DataOptions;
034
035import java.io.BufferedWriter;
036import java.io.File;
037import java.io.FileOutputStream;
038import java.io.FileWriter;
039import java.io.IOException;
040import java.io.ObjectOutputStream;
041import java.io.OutputStreamWriter;
042import java.io.PrintWriter;
043import java.nio.charset.StandardCharsets;
044import java.util.HashMap;
045import java.util.List;
046import java.util.Map;
047import java.util.logging.Level;
048import java.util.logging.Logger;
049
050/**
051 * Trains and tests a model using the supplied data, for each trainer inside a configuration file.
052 */
053public class RunAll {
054    private static final Logger logger = Logger.getLogger(RunAll.class.getName());
055
056    public static class RunAllOptions implements Options {
057        @Override
058        public String getOptionsDescription() {
059            return "Performs the same training and test experiment on all Trainers in the supplied configuration file.";
060        }
061        public DataOptions general;
062
063        @Option(charName='d',longName="output-directory",usage="Directory to write out the models and test reports.")
064        public File directory;
065    }
066
067    public static void main(String[] args) throws IOException {
068        LabsLogFormatter.setAllLogFormatters();
069
070        RunAllOptions o = new RunAllOptions();
071        ConfigurationManager cm;
072        try {
073            cm = new ConfigurationManager(args,o);
074        } catch (UsageException e) {
075            logger.info(e.getMessage());
076            return;
077        }
078
079        if (o.general.trainingPath == null || o.general.testingPath == null || o.directory == null) {
080            logger.info(cm.usage());
081            System.exit(1);
082        }
083        Pair<Dataset<Label>,Dataset<Label>> data = null;
084        try {
085            data = o.general.load(new LabelFactory());
086        } catch (IOException e) {
087            logger.log(Level.SEVERE, "Failed to load data", e);
088            System.exit(1);
089        }
090        Dataset<Label> train = data.getA();
091        Dataset<Label> test = data.getB();
092
093        logger.info("Creating directory - " + o.directory.toString());
094        if (!o.directory.exists() && !o.directory.mkdirs()) {
095            logger.warning("Failed to create directory.");
096        }
097
098        Map<String,Double> performances = new HashMap<>();
099        List<Trainer> trainers = cm.lookupAll(Trainer.class);
100        for (Trainer<?> t : trainers) {
101            String name = t.getClass().getSimpleName();
102            logger.info("Training model using " + t.toString());
103            @SuppressWarnings("unchecked") // configuration system cast.
104            Model<Label> curModel = ((Trainer<Label>)t).train(train);
105            LabelEvaluator evaluator = new LabelEvaluator();
106            LabelEvaluation evaluation = evaluator.evaluate(curModel,test);
107            Double old = performances.put(name,evaluation.microAveragedF1());
108            if (old != null) {
109                logger.info("Found two trainers with the name " + name);
110            }
111            String outputPath = o.directory.toString()+"/"+name;
112            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputPath+".model"))) {
113                oos.writeObject(curModel);
114            }
115            try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outputPath+".output"), StandardCharsets.UTF_8))) {
116                writer.println("Model = " + name);
117                writer.println("Provenance = " + curModel.toString());
118                writer.println();
119                ConfusionMatrix<Label> matrix = evaluation.getConfusionMatrix();
120                writer.println("ConfusionMatrix:\n" + matrix.toString());
121                writer.println();
122                writer.println("Evaluation:\n" + evaluation.toString());
123            }
124        }
125
126        for (Map.Entry<String,Double> e : performances.entrySet()) {
127            logger.info("Trainer = " + e.getKey() + ", F1 = " + e.getValue());
128        }
129
130    }
131
132}