001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.sgd;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import com.oracle.labs.mlrg.olcut.util.Pair;
025import org.tribuo.Dataset;
026import org.tribuo.Model;
027import org.tribuo.Trainer;
028import org.tribuo.data.DataOptions;
029import org.tribuo.math.StochasticGradientOptimiser;
030import org.tribuo.math.optimisers.GradientOptimiserOptions;
031import org.tribuo.regression.RegressionFactory;
032import org.tribuo.regression.Regressor;
033import org.tribuo.regression.evaluation.RegressionEvaluation;
034import org.tribuo.regression.sgd.linear.LinearSGDTrainer;
035import org.tribuo.regression.sgd.objectives.AbsoluteLoss;
036import org.tribuo.regression.sgd.objectives.Huber;
037import org.tribuo.regression.sgd.objectives.SquaredLoss;
038import org.tribuo.util.Util;
039
040import java.io.IOException;
041import java.util.logging.Logger;
042
043/**
044 * Build and run a linear regression for a standard dataset.
045 */
046public class TrainTest {
047
048    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());
049
050    public enum LossEnum { ABSOLUTE, SQUARED, HUBER }
051
052    public static class SGDOptions implements Options {
053        @Override
054        public String getOptionsDescription() {
055            return "Trains and tests a linear SGD regression model on the specified datasets.";
056        }
057        public DataOptions general;
058        public GradientOptimiserOptions gradientOptions;
059
060        @Option(charName='i',longName="epochs",usage="Number of SGD epochs. Defaults to 5.")
061        public int epochs = 5;
062        @Option(charName='o',longName="objective",usage="Loss function. Defaults to SQUARED.")
063        public LossEnum loss = LossEnum.SQUARED;
064        @Option(charName='p',longName="logging-interval",usage="Log the objective after <int> examples. Defaults to 100.")
065        public int loggingInterval = 100;
066        @Option(charName='z',longName="minibatch-size",usage="Minibatch size. Defaults to 1.")
067        public int minibatchSize = 1;
068    }
069
070    /**
071     * @param args the command line arguments
072     * @throws IOException if there is any error reading the examples.
073     */
074    public static void main(String[] args) throws IOException {
075
076        //
077        // Use the labs format logging.
078        LabsLogFormatter.setAllLogFormatters();
079
080        SGDOptions o = new SGDOptions();
081        ConfigurationManager cm;
082        try {
083            cm = new ConfigurationManager(args,o);
084        } catch (UsageException e) {
085            logger.info(e.getMessage());
086            return;
087        }
088
089        if (o.general.trainingPath == null || o.general.testingPath == null) {
090            logger.info(cm.usage());
091            return;
092        }
093
094        logger.info("Configuring gradient optimiser");
095        RegressionObjective obj = null;
096        switch (o.loss) {
097            case ABSOLUTE:
098                obj = new AbsoluteLoss();
099                break;
100            case SQUARED:
101                obj = new SquaredLoss();
102                break;
103            case HUBER:
104                obj = new Huber();
105                break;
106            default:
107                logger.warning("Unknown objective function " + o.loss);
108                logger.info(cm.usage());
109                return;
110        }
111        StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
112
113        logger.info(String.format("Set logging interval to %d",o.loggingInterval));
114        RegressionFactory factory = new RegressionFactory();
115
116        Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory);
117        Dataset<Regressor> train = data.getA();
118        Dataset<Regressor> test = data.getB();
119
120        Trainer<Regressor> trainer = new LinearSGDTrainer(obj,grad,o.epochs,o.loggingInterval,o.minibatchSize,o.general.seed);
121        logger.info("Training using " + trainer.toString());
122        final long trainStart = System.currentTimeMillis();
123        Model<Regressor> model = trainer.train(train);
124        final long trainStop = System.currentTimeMillis();
125
126        logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop));
127
128        final long testStart = System.currentTimeMillis();
129        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test);
130        final long testStop = System.currentTimeMillis();
131        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
132        System.out.println(evaluation.toString());
133
134        //System.out.println("Features - " + model.getTopFeatures(model.getFeatureIDMap().size()+1));
135
136        if (o.general.outputPath != null) {
137            o.general.saveModel(model);
138        }
139    }
140}