001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification;
018
019import com.oracle.labs.mlrg.olcut.config.ArgumentException;
020import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
021import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
022import com.oracle.labs.mlrg.olcut.util.Pair;
023import org.tribuo.Dataset;
024import org.tribuo.Model;
025import org.tribuo.Trainer;
026import org.tribuo.classification.evaluation.LabelEvaluation;
027import org.tribuo.data.DataOptions;
028import org.tribuo.util.Util;
029
030import java.io.IOException;
031import java.util.logging.Logger;
032
033/**
034 * This class provides static methods used by the demo classes in each classification backend.
035 */
036public final class TrainTestHelper {
037
038    private static final Logger logger = Logger.getLogger(TrainTestHelper.class.getName());
039
040    private static final LabelFactory factory = new LabelFactory();
041
042    private TrainTestHelper() { }
043
044    /**
045     * This method trains a model on the specified training data, and evaluates it
046     * on the specified test data. It writes out the timing to it's logger, and the
047     * statistical performance to standard out. If set, the model is written out
048     * to the specified path on disk.
049     * @param cm The configuration manager which knows the arguments.
050     * @param dataOptions The data options which specify the training and test data.
051     * @param trainer The trainer to use.
052     * @return The trained model.
053     * @throws IOException If the data failed to load.
054     */
055    public static Model<Label> run(ConfigurationManager cm, DataOptions dataOptions, Trainer<Label> trainer) throws IOException {
056        LabsLogFormatter.setAllLogFormatters();
057
058        if (dataOptions.trainingPath == null || dataOptions.testingPath == null) {
059            logger.info(cm.usage());
060            logger.info("Training Path = " + dataOptions.trainingPath + ", Testing Path = " + dataOptions.testingPath);
061            throw new ArgumentException("training-file","test-file","Must supply both training and testing data.");
062        }
063
064        Pair<Dataset<Label>, Dataset<Label>> data = dataOptions.load(factory);
065        Dataset<Label> train = data.getA();
066        logger.info("Training data has " + train.getFeatureIDMap().size() + " features.");
067
068        Dataset<Label> test = data.getB();
069
070        logger.info("Training using " + trainer.toString());
071        final long trainStart = System.currentTimeMillis();
072        Model<Label> model = trainer.train(train);
073        final long trainStop = System.currentTimeMillis();
074        logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
075        final long testStart = System.currentTimeMillis();
076        LabelEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
077        final long testStop = System.currentTimeMillis();
078        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
079
080        if (model.generatesProbabilities()) {
081            logger.info("Average AUC = " + evaluation.averageAUCROC(false));
082            logger.info("Average weighted AUC = " + evaluation.averageAUCROC(true));
083        }
084
085        System.out.println(evaluation.toString());
086
087        System.out.println(evaluation.getConfusionMatrix().toString());
088
089        if (dataOptions.outputPath != null) {
090            dataOptions.saveModel(model);
091        }
092
093        return model;
094    }
095}