001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification; 018 019import com.oracle.labs.mlrg.olcut.config.ArgumentException; 020import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 021import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 022import com.oracle.labs.mlrg.olcut.util.Pair; 023import org.tribuo.Dataset; 024import org.tribuo.Model; 025import org.tribuo.Trainer; 026import org.tribuo.classification.evaluation.LabelEvaluation; 027import org.tribuo.data.DataOptions; 028import org.tribuo.util.Util; 029 030import java.io.IOException; 031import java.util.logging.Logger; 032 033/** 034 * This class provides static methods used by the demo classes in each classification backend. 035 */ 036public final class TrainTestHelper { 037 038 private static final Logger logger = Logger.getLogger(TrainTestHelper.class.getName()); 039 040 private static final LabelFactory factory = new LabelFactory(); 041 042 private TrainTestHelper() { } 043 044 /** 045 * This method trains a model on the specified training data, and evaluates it 046 * on the specified test data. It writes out the timing to it's logger, and the 047 * statistical performance to standard out. If set, the model is written out 048 * to the specified path on disk. 049 * @param cm The configuration manager which knows the arguments. 050 * @param dataOptions The data options which specify the training and test data. 051 * @param trainer The trainer to use. 052 * @return The trained model. 053 * @throws IOException If the data failed to load. 054 */ 055 public static Model<Label> run(ConfigurationManager cm, DataOptions dataOptions, Trainer<Label> trainer) throws IOException { 056 LabsLogFormatter.setAllLogFormatters(); 057 058 if (dataOptions.trainingPath == null || dataOptions.testingPath == null) { 059 logger.info(cm.usage()); 060 logger.info("Training Path = " + dataOptions.trainingPath + ", Testing Path = " + dataOptions.testingPath); 061 throw new ArgumentException("training-file","test-file","Must supply both training and testing data."); 062 } 063 064 Pair<Dataset<Label>, Dataset<Label>> data = dataOptions.load(factory); 065 Dataset<Label> train = data.getA(); 066 logger.info("Training data has " + train.getFeatureIDMap().size() + " features."); 067 068 Dataset<Label> test = data.getB(); 069 070 logger.info("Training using " + trainer.toString()); 071 final long trainStart = System.currentTimeMillis(); 072 Model<Label> model = trainer.train(train); 073 final long trainStop = System.currentTimeMillis(); 074 logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop)); 075 final long testStart = System.currentTimeMillis(); 076 LabelEvaluation evaluation = factory.getEvaluator().evaluate(model, test); 077 final long testStop = System.currentTimeMillis(); 078 logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop)); 079 080 if (model.generatesProbabilities()) { 081 logger.info("Average AUC = " + evaluation.averageAUCROC(false)); 082 logger.info("Average weighted AUC = " + evaluation.averageAUCROC(true)); 083 } 084 085 System.out.println(evaluation.toString()); 086 087 System.out.println(evaluation.getConfusionMatrix().toString()); 088 089 if (dataOptions.outputPath != null) { 090 dataOptions.saveModel(model); 091 } 092 093 return model; 094 } 095}