/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.sgd;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.data.DataOptions;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.GradientOptimiserOptions;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.sgd.linear.LinearSGDTrainer;
import org.tribuo.regression.sgd.objectives.AbsoluteLoss;
import org.tribuo.regression.sgd.objectives.Huber;
import org.tribuo.regression.sgd.objectives.SquaredLoss;
import org.tribuo.util.Util;

import java.io.IOException;
import java.util.logging.Logger;

/**
 * Build and run a linear regression for a standard dataset.
 */
public class TrainTest {

    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    /**
     * The type of loss function used by the trainer.
     */
    public enum LossEnum { ABSOLUTE, SQUARED, HUBER }

    /**
     * Command line options for the linear SGD regression trainer.
     */
    public static class SGDOptions implements Options {
        @Override
        public String getOptionsDescription() {
            return "Trains and tests a linear SGD regression model on the specified datasets.";
        }

        /**
         * The data loading options.
         */
        public DataOptions general;
        /**
         * The gradient optimiser options.
         */
        public GradientOptimiserOptions gradientOptions;

        @Option(charName='i',longName="epochs",usage="Number of SGD epochs. Defaults to 5.")
        public int epochs = 5;
        @Option(charName='o',longName="objective",usage="Loss function. Defaults to SQUARED.")
        public LossEnum loss = LossEnum.SQUARED;
        @Option(charName='p',longName="logging-interval",usage="Log the objective after <int> examples. Defaults to 100.")
        public int loggingInterval = 100;
        @Option(charName='z',longName="minibatch-size",usage="Minibatch size. Defaults to 1.")
        public int minibatchSize = 1;
    }

    /**
     * Runs a TrainTest CLI.
     * @param args the command line arguments
     * @throws IOException if there is any error reading the examples.
     */
    public static void main(String[] args) throws IOException {

        //
        // Use the labs format logging.
        LabsLogFormatter.setAllLogFormatters();

        SGDOptions o = new SGDOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args,o);
        } catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }

        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }

        logger.info("Configuring gradient optimiser");
        // Pick the loss function requested on the command line.
        RegressionObjective obj = null;
        switch (o.loss) {
            case ABSOLUTE:
                obj = new AbsoluteLoss();
                break;
            case SQUARED:
                obj = new SquaredLoss();
                break;
            case HUBER:
                obj = new Huber();
                break;
            default:
                logger.warning("Unknown objective function " + o.loss);
                logger.info(cm.usage());
                return;
        }
        StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();

        logger.info(String.format("Set logging interval to %d",o.loggingInterval));
        RegressionFactory factory = new RegressionFactory();

        // Load the training and testing datasets.
        Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory);
        Dataset<Regressor> train = data.getA();
        Dataset<Regressor> test = data.getB();

        // Train the linear model, timing how long it takes.
        Trainer<Regressor> trainer = new LinearSGDTrainer(obj,grad,o.epochs,o.loggingInterval,o.minibatchSize,o.general.seed);
        logger.info("Training using " + trainer.toString());
        final long trainStart = System.currentTimeMillis();
        Model<Regressor> model = trainer.train(train);
        final long trainStop = System.currentTimeMillis();

        logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop));

        // Evaluate the model on the held-out test set.
        final long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
        System.out.println(evaluation.toString());

        //System.out.println("Features - " + model.getTopFeatures(model.getFeatureIDMap().size()+1));

        // Save the model if an output path was supplied.
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    }
}
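
// Example invocation (a sketch, not taken from the Tribuo documentation).
// The --epochs, --objective, --logging-interval and --minibatch-size flags are
// the ones declared in SGDOptions above; the training/testing path and output
// flags live in DataOptions, and the optimiser flags in GradientOptimiserOptions,
// so run the class with no arguments (which prints the usage) or consult those
// classes for their exact names. The classpath entry below is a placeholder for
// the relevant Tribuo jar.
//
//   java -cp <tribuo-regression-sgd-jar> org.tribuo.regression.sgd.TrainTest \
//        <DataOptions flags for the train/test paths and input format> \
//        --objective HUBER --epochs 10 --minibatch-size 32 --logging-interval 1000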