/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.xgboost;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.common.xgboost.XGBoostModel;
import org.tribuo.common.xgboost.XGBoostTrainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A {@link Trainer} which wraps the XGBoost training procedure.
 * This only exposes a few of XGBoost's training parameters.
 * It uses pthreads outside of the JVM to parallelise the computation.
 * <p>
 * Each output dimension is trained independently (and so contains a separate XGBoost ensemble).
 * <p>
 * See:
 * <pre>
 * Chen T, Guestrin C.
 * "XGBoost: A Scalable Tree Boosting System"
 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Friedman JH.
 * "Greedy Function Approximation: a Gradient Boosting Machine"
 * Annals of Statistics, 2001.
 * </pre>
 * <p>
 * Note: XGBoost requires a native library; on macOS this library requires libomp (which can be installed via
 * Homebrew), and on Windows the native library must be compiled into a jar as it's not contained in the official
 * XGBoost binary on Maven Central.
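 * <p>
 * A minimal usage sketch (here {@code trainData} is an illustrative placeholder for any
 * pre-built {@code Dataset<Regressor>}, not part of this API):
 * <pre>{@code
 * // Boost 50 trees with squared error loss and the default hyperparameters.
 * XGBoostRegressionTrainer trainer = new XGBoostRegressionTrainer(50);
 * Model<Regressor> model = trainer.train(trainData);
 * }</pre>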
 */
public final class XGBoostRegressionTrainer extends XGBoostTrainer<Regressor> {

    private static final Logger logger = Logger.getLogger(XGBoostRegressionTrainer.class.getName());

    /**
     * Types of regression loss.
     */
    public enum RegressionType {
        /**
         * Squared error loss function.
         */
        LINEAR("reg:squarederror"),
        /**
         * Gamma loss function.
         */
        GAMMA("reg:gamma"),
        /**
         * Tweedie loss function.
         */
        TWEEDIE("reg:tweedie");

        /**
         * The name of the objective parameter passed to the XGBoost native library.
         */
        public final String paramName;

        RegressionType(String paramName) {
            this.paramName = paramName;
        }
    }

    @Config(description="The type of regression.")
    private RegressionType rType = RegressionType.LINEAR;

    /**
     * Creates an XGBoost trainer using the squared error loss and the default hyperparameters.
     * @param numTrees Number of trees to boost.
     */
    public XGBoostRegressionTrainer(int numTrees) {
        this(RegressionType.LINEAR, numTrees);
    }

    /**
     * Creates an XGBoost trainer using the supplied loss and the default hyperparameters.
     * @param rType The type of regression to build.
     * @param numTrees Number of trees to boost.
     */
    public XGBoostRegressionTrainer(RegressionType rType, int numTrees) {
        this(rType, numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, 4, true, Trainer.DEFAULT_SEED);
    }

    /**
     * Creates an XGBoost trainer using the supplied loss, thread count and output setting,
     * and the default values for the other hyperparameters.
     * @param rType The type of regression to build.
     * @param numTrees Number of trees to boost.
     * @param numThreads Number of threads to use.
     * @param silent Silence the training output text.
     */
    public XGBoostRegressionTrainer(RegressionType rType, int numTrees, int numThreads, boolean silent) {
        this(rType, numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, numThreads, silent, Trainer.DEFAULT_SEED);
    }

    /**
     * Create an XGBoost trainer.
     *
     * @param rType The type of regression to build.
     * @param numTrees Number of trees to boost.
     * @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
     * @param gamma Minimum loss reduction to make a split (default 0, range
     *              [0,inf]).
     * @param maxDepth Maximum tree depth (default 6, range [1,inf]).
     * @param minChildWeight Minimum sum of instance weights needed in a leaf
     *                       (default 1, range [0,inf]).
     * @param subsample Subsample size for each tree (default 1, range (0,1]).
     * @param featureSubsample Subsample features for each tree (default 1,
     *                         range (0,1]).
     * @param lambda L2 regularization term on weights (default 1).
     * @param alpha L1 regularization term on weights (default 0).
     * @param nThread Number of threads to use (default 4).
     * @param silent Silence the training output text.
     * @param seed RNG seed.
     */
    public XGBoostRegressionTrainer(RegressionType rType, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed) {
        super(numTrees,eta,gamma,maxDepth,minChildWeight,subsample,featureSubsample,lambda,alpha,nThread,silent,seed);
        this.rType = rType;

        postConfig();
    }

    /**
     * This gives direct access to the XGBoost parameter map.
     * <p>
     * It lets you pick things that we haven't exposed, like dropout trees, binary classification etc.
     * <p>
     * This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results.
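     * <p>
     * For example, a sketch of passing XGBoost parameters through directly (the parameter names
     * here, such as {@code tweedie_variance_power}, are XGBoost's own; consult the XGBoost
     * documentation for the valid names and ranges):
     * <pre>{@code
     * Map<String,Object> params = new HashMap<>();
     * params.put("eta", 0.1);
     * params.put("tweedie_variance_power", 1.5);
     * XGBoostRegressionTrainer trainer =
     *     new XGBoostRegressionTrainer(RegressionType.TWEEDIE, 100, params);
     * }</pre>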
     * @param rType The type of the regression.
     * @param numTrees Number of trees to boost.
     * @param parameters A map from string to object, where the object can be a Number or a String.
     */
    public XGBoostRegressionTrainer(RegressionType rType, int numTrees, Map<String,Object> parameters) {
        super(numTrees,parameters);
        this.rType = rType;
        postConfig();
    }

    /**
     * For OLCUT.
     */
    private XGBoostRegressionTrainer() { }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        parameters.put("objective",rType.paramName);
    }

    @Override
    public synchronized XGBoostModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
        int numOutputs = outputInfo.size();
        TrainerProvenance trainerProvenance = getProvenance();
        trainInvocationCounter++;
        List<Booster> models = new ArrayList<>();
        try {
            // Use a null response extractor as we'll do the per dimension regression extraction later.
            DMatrixTuple<Regressor> trainingData = convertExamples(examples, featureMap, null);

            // Extract the weights and the regression targets.
            float[][] outputs = new float[numOutputs][examples.size()];
            float[] weights = new float[examples.size()];
            int i = 0;
            for (Example<Regressor> e : examples) {
                weights[i] = e.getWeight();
                double[] curOutputs = e.getOutput().getValues();
                // Transpose them for easy training.
                for (int j = 0; j < numOutputs; j++) {
                    outputs[j][i] = (float) curOutputs[j];
                }
                i++;
            }
            trainingData.data.setWeight(weights);

            // Finished setup, now train one model per dimension.
            for (i = 0; i < numOutputs; i++) {
                trainingData.data.setLabel(outputs[i]);
                models.add(XGBoost.train(trainingData.data, parameters, numTrees, Collections.emptyMap(), null, null));
            }
        } catch (XGBoostError e) {
            logger.log(Level.SEVERE, "XGBoost threw an error", e);
            throw new IllegalStateException(e);
        }

        ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        XGBoostModel<Regressor> xgModel = createModel("xgboost-regression-model", provenance, featureMap, outputInfo, models, new XGBoostRegressionConverter());

        return xgModel;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}