/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.xgboost;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.data.DataOptions;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.xgboost.XGBoostRegressionTrainer.RegressionType;
import org.tribuo.util.Util;

import java.io.IOException;
import java.util.logging.Logger;

/**
 * Build and run an XGBoost regressor for a standard dataset.
 */
public class TrainTest {

    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

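    /**
     * Command line options for configuring the XGBoost regression trainer.
     */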
    public static class XGBoostOptions implements Options {
        @Override
        public String getOptionsDescription() {
            return "Trains and tests an XGBoost regression model on the specified datasets.";
        }
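        /**
         * The data loading options.
         */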
        public DataOptions general;

        @Option(longName="regression-metric",usage="Regression type to use. Defaults to LINEAR.")
        public RegressionType rType = RegressionType.LINEAR;
        @Option(charName='m',longName="ensemble-size",usage="Number of trees in the ensemble.")
        public int ensembleSize = -1;
        @Option(charName='a',longName="alpha",usage="L1 regularization term for weights (default 0).")
        public float alpha = 0.0f;
        @Option(longName="min-weight",usage="Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).")
        public float minWeight = 1;
        @Option(charName='d',longName="max-depth",usage="Max tree depth (default 6, range (0,inf]).")
        public int depth = 6;
        @Option(charName='e',longName="eta",usage="Step size shrinkage parameter (default 0.3, range [0,1]).")
        public float eta = 0.3f;
        @Option(longName="subsample-features",usage="Subsample features for each tree (default 1, range (0,1]).")
        public float subsampleFeatures = 1.0f;
        @Option(charName='g',longName="gamma",usage="Minimum loss reduction to make a split (default 0, range [0,inf]).")
        public float gamma = 0.0f;
        @Option(charName='l',longName="lambda",usage="L2 regularization term for weights (default 1).")
        public float lambda = 1.0f;
        @Option(charName='q',longName="quiet",usage="Make the XGBoost training procedure quiet.")
        public boolean quiet;
        @Option(longName="subsample",usage="Subsample size for each tree (default 1, range (0,1]).")
        public float subsample = 1.0f;
        @Option(charName='t',longName="num-threads",usage="Number of threads to use (default 4, range (1, num hw threads)).")
        public int numThreads = 4;
    }

    /**
     * Runs a TrainTest CLI.
     * @param args the command line arguments
     * @throws java.io.IOException if there is any error reading the examples.
     */
    public static void main(String[] args) throws IOException {
        //
        // Use the labs format logging.
        LabsLogFormatter.setAllLogFormatters();

        XGBoostOptions o = new XGBoostOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args,o);
        } catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }

        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            logger.info("Please supply a training path and a testing path.");
            return;
        }

        if (o.ensembleSize == -1) {
            logger.info(cm.usage());
            logger.info("Please supply the number of trees.");
            return;
        }

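        // The RegressionFactory produces Regressor outputs and supplies the regression evaluator used below.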
        RegressionFactory factory = new RegressionFactory();

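        // Load the training and testing datasets using the shared data loading options.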
        Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory);
        Dataset<Regressor> train = data.getA();
        Dataset<Regressor> test = data.getB();

        // Constructor argument order: regressionType, numTrees, eta, gamma, maxDepth, minChildWeight, subsample, featureSubsample, lambda, alpha, numThreads, quiet, seed
        XGBoostRegressionTrainer trainer = new XGBoostRegressionTrainer(o.rType,o.ensembleSize,o.eta,o.gamma,o.depth,o.minWeight,o.subsample,o.subsampleFeatures,o.lambda,o.alpha,o.numThreads,o.quiet,o.general.seed);
        logger.info("Training using " + trainer.toString());
        final long trainStart = System.currentTimeMillis();
        Model<Regressor> model = trainer.train(train);
        final long trainStop = System.currentTimeMillis();

        logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop));

        final long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
        System.out.println(evaluation.toString());

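        // Write the trained model to disk if an output path was supplied.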
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    }
}