/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.xgboost;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.data.DataOptions;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.xgboost.XGBoostRegressionTrainer.RegressionType;
import org.tribuo.util.Util;

import java.io.IOException;
import java.util.logging.Logger;

/**
 * Build and run an XGBoost regressor for a standard dataset.
039 */ 040public class TrainTest { 041 042 private static final Logger logger = Logger.getLogger(TrainTest.class.getName()); 043 044 public static class XGBoostOptions implements Options { 045 @Override 046 public String getOptionsDescription() { 047 return "Trains and tests an XGBoost regression model on the specified datasets."; 048 } 049 public DataOptions general; 050 051 @Option(longName="regression-metric", usage="Regression type to use. Defaults to LINEAR.") 052 public RegressionType rType = RegressionType.LINEAR; 053 @Option(charName='m',longName="ensemble-size",usage="Number of trees in the ensemble.") 054 public int ensembleSize = -1; 055 @Option(charName='a',longName="alpha",usage="L1 regularization term for weights (default 0).") 056 public float alpha = 0.0f; 057 @Option(longName="min-weight",usage="Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).") 058 public float minWeight = 1; 059 @Option(charName='d',longName="max-depth",usage="Max tree depth (default 6, range (0,inf]).") 060 public int depth = 6; 061 @Option(charName='e',longName="eta",usage="Step size shrinkage parameter (default 0.3, range [0,1]).") 062 public float eta = 0.3f; 063 @Option(longName="subsample-features",usage="Subsample features for each tree (default 1, range (0,1]).") 064 public float subsampleFeatures = 1.0f; 065 @Option(charName='g',longName="gamma",usage="Minimum loss reduction to make a split (default 0, range [0,inf]).") 066 public float gamma = 0.0f; 067 @Option(charName='l',longName="lambda",usage="L2 regularization term for weights (default 1).") 068 public float lambda = 1.0f; 069 @Option(charName='q',longName="quiet",usage="Make the XGBoost training procedure quiet.") 070 public boolean quiet; 071 @Option(longName="subsample",usage="Subsample size for each tree (default 1, range (0,1]).") 072 public float subsample = 1.0f; 073 @Option(charName='t',longName="num-threads",usage="Number of threads to use (default 4, range (1, num hw 
threads)).") 074 public int numThreads = 4; 075 } 076 077 /** 078 * @param args the command line arguments 079 * @throws java.io.IOException if there is any error reading the examples. 080 */ 081 public static void main(String[] args) throws IOException { 082 // 083 // Use the labs format logging. 084 LabsLogFormatter.setAllLogFormatters(); 085 086 XGBoostOptions o = new XGBoostOptions(); 087 ConfigurationManager cm; 088 try { 089 cm = new ConfigurationManager(args,o); 090 } catch (UsageException e) { 091 logger.info(e.getMessage()); 092 return; 093 } 094 095 if (o.general.trainingPath == null || o.general.testingPath == null) { 096 logger.info(cm.usage()); 097 logger.info("Please supply a training path and a testing path"); 098 return; 099 } 100 101 if (o.ensembleSize == -1) { 102 logger.info(cm.usage()); 103 logger.info("Please supply the number of trees."); 104 return; 105 } 106 107 RegressionFactory factory = new RegressionFactory(); 108 109 Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory); 110 Dataset<Regressor> train = data.getA(); 111 Dataset<Regressor> test = data.getB(); 112 113 //public XGBoostRegressionTrainer(RegressionType rType, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, long seed) { 114 XGBoostRegressionTrainer trainer = new XGBoostRegressionTrainer(o.rType,o.ensembleSize,o.eta,o.gamma,o.depth,o.minWeight,o.subsample,o.subsampleFeatures,o.lambda,o.alpha,o.numThreads,o.quiet,o.general.seed); 115 logger.info("Training using " + trainer.toString()); 116 final long trainStart = System.currentTimeMillis(); 117 Model<Regressor> model = trainer.train(train); 118 final long trainStop = System.currentTimeMillis(); 119 120 logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop)); 121 122 final long testStart = System.currentTimeMillis(); 123 RegressionEvaluation evaluation = 
factory.getEvaluator().evaluate(model,test); 124 final long testStop = System.currentTimeMillis(); 125 logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop)); 126 System.out.println(evaluation.toString()); 127 128 if (o.general.outputPath != null) { 129 o.general.saveModel(model); 130 } 131 } 132}