001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.xgboost; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import org.tribuo.Dataset; 022import org.tribuo.Example; 023import org.tribuo.Feature; 024import org.tribuo.ImmutableFeatureMap; 025import org.tribuo.ImmutableOutputInfo; 026import org.tribuo.Output; 027import org.tribuo.Trainer; 028import org.tribuo.WeightedExamples; 029import org.tribuo.math.la.SparseVector; 030import org.tribuo.math.la.VectorTuple; 031import org.tribuo.provenance.ModelProvenance; 032import org.tribuo.provenance.SkeletalTrainerProvenance; 033import org.tribuo.util.Util; 034import ml.dmlc.xgboost4j.java.Booster; 035import ml.dmlc.xgboost4j.java.DMatrix; 036import ml.dmlc.xgboost4j.java.XGBoostError; 037 038import java.util.ArrayList; 039import java.util.HashMap; 040import java.util.List; 041import java.util.Map; 042import java.util.function.Function; 043import java.util.logging.Logger; 044 045/** 046 * A {@link Trainer} which wraps the XGBoost training procedure. 047 * <p> 048 * This only exposes a few of XGBoost's training parameters. 049 * <p> 050 * It uses pthreads outside of the JVM to parallelise the computation. 051 * <p> 052 * See: 053 * <pre> 054 * Chen T, Guestrin C. 055 * "XGBoost: A Scalable Tree Boosting System" 056 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016. 057 * </pre> 058 * and for the original algorithm: 059 * <pre> 060 * Friedman JH. 061 * "Greedy Function Approximation: a Gradient Boosting Machine" 062 * Annals of statistics, 2001. 063 * </pre> 064 * N.B.: This uses a native C implementation of xgboost that links to various C libraries, including libgomp 065 * and glibc. If you're running on Alpine, which does not natively use glibc, you'll need to install glibc 066 * into the container. On Windows this binary is not available in the Maven Central release, you'll need 067 * to compile it from source. 068 */ 069public abstract class XGBoostTrainer<T extends Output<T>> implements Trainer<T>, WeightedExamples { 070 /* Alpine install command 071 * <pre> 072 * $ apk --no-cache add ca-certificates wget 073 * $ wget -q -O /etc/apk/keys/sgerrand.rsa.pub https://alpine-pkgs.sgerrand.com/sgerrand.rsa.pub 074 * $ wget https://github.com/sgerrand/alpine-pkg-glibc/releases/download/2.30-r0/glibc-2.30-r0.apk 075 * $ apk add glibc-2.30-r0.apk 076 * </pre> 077 */ 078 079 private static final Logger logger = Logger.getLogger(XGBoostTrainer.class.getName()); 080 081 protected final Map<String, Object> parameters = new HashMap<>(); 082 083 /** 084 * The type of XGBoost model. 085 */ 086 public enum BoosterType { 087 /** 088 * A boosted linear model. 089 */ 090 LINEAR("gblinear"), 091 /** 092 * A gradient boosted decision tree. 093 */ 094 GBTREE("gbtree"), 095 /** 096 * A gradient boosted decision tree using dropout. 097 */ 098 DART("dart"); 099 100 public final String paramName; 101 102 BoosterType(String paramName) { 103 this.paramName = paramName; 104 } 105 } 106 107 @Config(mandatory = true,description="The number of trees to build.") 108 protected int numTrees; 109 110 @Config(description = "The learning rate, shrinks the new tree output to prevent overfitting.") 111 private double eta = 0.3; 112 113 @Config(description = "Minimum loss reduction needed to split a tree node.") 114 private double gamma = 0.0; 115 116 @Config(description="The maximum depth of any tree.") 117 private int maxDepth = 6; 118 119 @Config(description = "The minimum weight in each child node before a split is valid.") 120 private double minChildWeight = 1.0; 121 122 @Config(description="Independently subsample the examples for each tree.") 123 private double subsample = 1.0; 124 125 @Config(description="Independently subsample the features available for each node of each tree.") 126 private double featureSubsample = 1.0; 127 128 @Config(description="l2 regularisation term on the weights.") 129 private double lambda = 1.0; 130 131 @Config(description="l1 regularisation term on the weights.") 132 private double alpha = 1.0; 133 134 @Config(description="The number of threads to use at training time.") 135 private int nThread = 4; 136 137 @Config(description="Quiesce all the logging output from the XGBoost C library.") 138 private int silent = 1; 139 140 @Config(description="Type of the weak learner.") 141 private BoosterType booster = BoosterType.GBTREE; 142 143 @Config(description="The RNG seed.") 144 private long seed = Trainer.DEFAULT_SEED; 145 146 protected int trainInvocationCounter = 0; 147 148 protected XGBoostTrainer(int numTrees) { 149 this(numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, 4, true, Trainer.DEFAULT_SEED); 150 } 151 152 protected XGBoostTrainer(int numTrees, int numThreads, boolean silent) { 153 this(numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, numThreads, silent, Trainer.DEFAULT_SEED); 154 } 155 156 /** 157 * Create an XGBoost trainer. 158 * 159 * @param numTrees Number of trees to boost. 160 * @param eta Step size shrinkage parameter (default 0.3, range [0,1]). 161 * @param gamma Minimum loss reduction to make a split (default 0, range 162 * [0,inf]). 163 * @param maxDepth Maximum tree depth (default 6, range [1,inf]). 164 * @param minChildWeight Minimum sum of instance weights needed in a leaf 165 * (default 1, range [0, inf]). 166 * @param subsample Subsample size for each tree (default 1, range (0,1]). 167 * @param featureSubsample Subsample features for each tree (default 1, 168 * range (0,1]). 169 * @param lambda L2 regularization term on weights (default 1). 170 * @param alpha L1 regularization term on weights (default 0). 171 * @param nThread Number of threads to use (default 4). 172 * @param silent Silence the training output text. 173 * @param seed RNG seed. 174 */ 175 protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed) { 176 if (numTrees < 1) { 177 throw new IllegalArgumentException("Must supply a positive number of trees. Received " + numTrees); 178 } 179 this.numTrees = numTrees; 180 this.eta = eta; 181 this.gamma = gamma; 182 this.maxDepth = maxDepth; 183 this.minChildWeight = minChildWeight; 184 this.subsample = subsample; 185 this.featureSubsample = featureSubsample; 186 this.lambda = lambda; 187 this.alpha = alpha; 188 this.nThread = nThread; 189 this.silent = silent ? 1 : 0; 190 this.seed = seed; 191 } 192 193 /** 194 * This gives direct access to the XGBoost parameter map. 195 * <p> 196 * It lets you pick things that we haven't exposed like dropout trees, binary classification etc. 197 * <p> 198 * This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results. 199 * @param numTrees Number of trees to boost. 200 * @param parameters A map from string to object, where object can be Number or String. 201 */ 202 protected XGBoostTrainer(int numTrees, Map<String,Object> parameters) { 203 if (numTrees < 1) { 204 throw new IllegalArgumentException("Must supply a positive number of trees. Received " + numTrees); 205 } 206 this.numTrees = numTrees; 207 this.parameters.putAll(parameters); 208 } 209 210 /** 211 * For olcut. 212 */ 213 protected XGBoostTrainer() { } 214 215 /** 216 * Used by the OLCUT configuration system, and should not be called by external code. 217 */ 218 @Override 219 public void postConfig() { 220 parameters.put("eta", eta); 221 parameters.put("gamma", gamma); 222 parameters.put("max_depth", maxDepth); 223 parameters.put("min_child_weight", minChildWeight); 224 parameters.put("subsample", subsample); 225 parameters.put("colsample_bytree", featureSubsample); 226 parameters.put("lambda", lambda); 227 parameters.put("alpha", alpha); 228 parameters.put("nthread", nThread); 229 parameters.put("seed", seed); 230 parameters.put("silent", silent); 231 parameters.put("booster", booster.paramName); 232 } 233 234 @Override 235 public String toString() { 236 StringBuilder buffer = new StringBuilder(); 237 238 buffer.append("XGBoostTrainer(numTrees="); 239 buffer.append(numTrees); 240 buffer.append(",parameters"); 241 buffer.append(parameters.toString()); 242 buffer.append(")"); 243 244 return buffer.toString(); 245 } 246 247 protected XGBoostModel<T> createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Booster> models, XGBoostOutputConverter<T> converter) { 248 return new XGBoostModel<>(name,provenance,featureIDMap,outputIDInfo,models,converter); 249 } 250 251 @Override 252 public int getInvocationCount() { 253 return trainInvocationCounter; 254 } 255 256 protected static <T extends Output<T>> DMatrixTuple<T> convertDataset(Dataset<T> examples, Function<T,Float> responseExtractor) throws XGBoostError { 257 return convertExamples(examples.getData(), examples.getFeatureIDMap(), responseExtractor); 258 } 259 260 protected static <T extends Output<T>> DMatrixTuple<T> convertDataset(Dataset<T> examples) throws XGBoostError { 261 return convertExamples(examples.getData(), examples.getFeatureIDMap(), null); 262 } 263 264 protected static <T extends Output<T>> DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap) throws XGBoostError { 265 return convertExamples(examples, featureMap, null); 266 } 267 268 /** 269 * Converts an iterable of examples into a DMatrix. 270 * @param examples The examples to convert. 271 * @param featureMap The feature id map which supplies the indices. 272 * @param responseExtractor The extraction function for the output. 273 * @param <T> The type of the output. 274 * @return A DMatrixTuple. 275 * @throws XGBoostError If the native library failed to construct the DMatrix. 276 */ 277 protected static <T extends Output<T>> DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) throws XGBoostError { 278 // headers = array of start points for a row 279 // indices = array of feature indices for all data 280 // data = array of feature values for all data 281 // SparseType = DMatrix.SparseType.CSR 282 //public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError 283 // 284 // then call 285 //public void setLabel(float[] labels) throws XGBoostError 286 287 boolean labelled = responseExtractor != null; 288 ArrayList<Float> labelsList = new ArrayList<>(); 289 ArrayList<Float> dataList = new ArrayList<>(); 290 ArrayList<Long> headersList = new ArrayList<>(); 291 ArrayList<Integer> indicesList = new ArrayList<>(); 292 ArrayList<Float> weightsList = new ArrayList<>(); 293 ArrayList<Integer> numValidFeatures = new ArrayList<>(); 294 ArrayList<Example<T>> examplesList = new ArrayList<>(); 295 296 long rowHeader = 0; 297 headersList.add(rowHeader); 298 for (Example<T> e : examples) { 299 if (labelled) { 300 labelsList.add(responseExtractor.apply(e.getOutput())); 301 weightsList.add(e.getWeight()); 302 } 303 examplesList.add(e); 304 long newRowHeader = convertSingleExample(e,featureMap,dataList,indicesList,headersList,rowHeader); 305 numValidFeatures.add((int) (newRowHeader-rowHeader)); 306 rowHeader = newRowHeader; 307 } 308 309 float[] data = Util.toPrimitiveFloat(dataList); 310 int[] indices = Util.toPrimitiveInt(indicesList); 311 long[] headers = Util.toPrimitiveLong(headersList); 312 313 DMatrix dataMatrix = new DMatrix(headers, indices, data, DMatrix.SparseType.CSR,featureMap.size()); 314 if (labelled) { 315 float[] labels = Util.toPrimitiveFloat(labelsList); 316 dataMatrix.setLabel(labels); 317 float[] weights = Util.toPrimitiveFloat(weightsList); 318 dataMatrix.setWeight(weights); 319 } 320 @SuppressWarnings("unchecked") // Generic array creation 321 Example<T>[] exampleArray = (Example<T>[])examplesList.toArray(new Example[0]); 322 return new DMatrixTuple<>(dataMatrix,Util.toPrimitiveInt(numValidFeatures),exampleArray); 323 } 324 325 protected static <T extends Output<T>> DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap) throws XGBoostError { 326 return convertExample(example,featureMap,null); 327 } 328 329 /** 330 * Converts an examples into a DMatrix. 331 * @param example The example to convert. 332 * @param featureMap The feature id map which supplies the indices. 333 * @param responseExtractor The extraction function for the output. 334 * @param <T> The type of the output. 335 * @return A DMatrixTuple. 336 * @throws XGBoostError If the native library failed to construct the DMatrix. 337 */ 338 protected static <T extends Output<T>> DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) throws XGBoostError { 339 // headers = array of start points for a row 340 // indices = array of feature indices for all data 341 // data = array of feature values for all data 342 // SparseType = DMatrix.SparseType.CSR 343 //public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError 344 // 345 // then call 346 //public void setLabel(float[] labels) throws XGBoostError 347 348 boolean labelled = responseExtractor != null; 349 ArrayList<Float> dataList = new ArrayList<>(); 350 ArrayList<Integer> indicesList = new ArrayList<>(); 351 ArrayList<Long> headersList = new ArrayList<>(); 352 headersList.add(0L); 353 354 long header = convertSingleExample(example,featureMap,dataList,indicesList,headersList,0); 355 356 float[] data = Util.toPrimitiveFloat(dataList); 357 int[] indices = Util.toPrimitiveInt(indicesList); 358 long[] headers = Util.toPrimitiveLong(headersList); 359 360 DMatrix dataMatrix = new DMatrix(headers, indices, data, DMatrix.SparseType.CSR,featureMap.size()); 361 if (labelled) { 362 float[] labels = new float[1]; 363 labels[0] = responseExtractor.apply(example.getOutput()); 364 dataMatrix.setLabel(labels); 365 float[] weights = new float[1]; 366 weights[0] = example.getWeight(); 367 dataMatrix.setWeight(weights); 368 } 369 @SuppressWarnings("unchecked") // Generic array creation 370 Example<T>[] exampleArray = (Example<T>[])new Example[]{example}; 371 return new DMatrixTuple<>(dataMatrix,new int[]{(int)header},exampleArray); 372 } 373 374 /** 375 * Writes out the features from an example into the three supplied {@link ArrayList}s. 376 * <p> 377 * This is used to transform examples into the right format for an XGBoost call. 378 * It's used in both the Classification and Regression XGBoost backends. 379 * The ArrayLists must be non-null, and can contain existing values (as this 380 * method is called multiple times to build up an arraylist containing all the 381 * feature values for a dataset). 382 * <p> 383 * Features with colliding feature ids are summed together. 384 * <p> 385 * Can throw IllegalArgumentException if the {@link Example} contains no features. 386 * @param example The example to inspect. 387 * @param featureMap The feature map of the model/dataset (used to preserve hash information). 388 * @param dataList The output feature values. 389 * @param indicesList The output indices. 390 * @param headersList The output header position (an integer saying how long each sparse example is). 391 * @param header The current header position. 392 * @param <T> The type of the example. 393 * @return The updated header position. 394 */ 395 protected static <T extends Output<T>> long convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header) { 396 int numActiveFeatures = 0; 397 int prevIdx = -1; 398 int indicesSize = indicesList.size(); 399 for (Feature f : example) { 400 int id = featureMap.getID(f.getName()); 401 if (id > prevIdx){ 402 prevIdx = id; 403 dataList.add((float) f.getValue()); 404 indicesList.add(id); 405 numActiveFeatures++; 406 } else if (id > -1) { 407 // 408 // Collision, deal with it. 409 int collisionIdx = Util.binarySearch(indicesList,id,indicesSize,numActiveFeatures+indicesSize); 410 if (collisionIdx < 0) { 411 // 412 // Collision but not present in tmpIndices 413 // move data and bump i 414 collisionIdx = - (collisionIdx + 1); 415 indicesList.add(collisionIdx,id); 416 dataList.add(collisionIdx,(float) f.getValue()); 417 numActiveFeatures++; 418 } else { 419 // 420 // Collision present in tmpIndices 421 // add the values. 422 dataList.set(collisionIdx, dataList.get(collisionIdx) + (float) f.getValue()); 423 } 424 } 425 } 426 if (numActiveFeatures == 0) { 427 throw new IllegalArgumentException("No features found in Example " + example.toString()); 428 } 429 header += numActiveFeatures; 430 headersList.add(header); 431 return header; 432 } 433 434 /** 435 * Writes out the features from a SparseVector into the three supplied {@link ArrayList}s. 436 * <p> 437 * This is used to transform examples into the right format for an XGBoost call. 438 * It's used when predicting with an externally trained XGBoost model, as the 439 * external training may not respect Tribuo's feature ordering constraints. 440 * The ArrayLists must be non-null, and can contain existing values (as this 441 * method is called multiple times to build up an arraylist containing all the 442 * feature values for a dataset). 443 * </p> 444 * <p> 445 * This is much simpler than {@link XGBoostTrainer#convertSingleExample} as the validation 446 * of feature indices is done in the {@link org.tribuo.interop.ExternalModel} class. 447 * </p> 448 * @param vector The features to convert. 449 * @param dataList The output feature values. 450 * @param indicesList The output indices. 451 * @param headersList The output header position (an integer saying how long each sparse example is). 452 * @param header The current header position. 453 * @return The updated header position. 454 */ 455 static long convertSingleExample(SparseVector vector, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header) { 456 int numActiveFeatures = 0; 457 for (VectorTuple v : vector) { 458 dataList.add((float) v.value); 459 indicesList.add(v.index); 460 numActiveFeatures++; 461 } 462 header += numActiveFeatures; 463 headersList.add(header); 464 return header; 465 } 466 467 /** 468 * Used when predicting with an externally trained XGBoost model. 469 * @param vector The features to convert. 470 * @return A DMatrix representing the features. 471 * @throws XGBoostError If the native library returns an error state. 472 */ 473 protected static DMatrix convertSparseVector(SparseVector vector) throws XGBoostError { 474 // headers = array of start points for a row 475 // indices = array of feature indices for all data 476 // data = array of feature values for all data 477 // SparseType = DMatrix.SparseType.CSR 478 //public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError 479 ArrayList<Float> dataList = new ArrayList<>(); 480 ArrayList<Long> headersList = new ArrayList<>(); 481 ArrayList<Integer> indicesList = new ArrayList<>(); 482 483 long rowHeader = 0; 484 headersList.add(rowHeader); 485 convertSingleExample(vector,dataList,indicesList,headersList,rowHeader); 486 487 float[] data = Util.toPrimitiveFloat(dataList); 488 int[] indices = Util.toPrimitiveInt(indicesList); 489 long[] headers = Util.toPrimitiveLong(headersList); 490 491 return new DMatrix(headers, indices, data, DMatrix.SparseType.CSR,vector.size()); 492 } 493 494 /** 495 * Used when predicting with an externally trained XGBoost model. 496 * <p> 497 * It is assumed all vectors are the same size when passed into this function. 498 * @param vectors The batch of features to convert. 499 * @return A DMatrix representing the batch of features. 500 * @throws XGBoostError If the native library returns an error state. 501 */ 502 protected static DMatrix convertSparseVectors(List<SparseVector> vectors) throws XGBoostError { 503 // headers = array of start points for a row 504 // indices = array of feature indices for all data 505 // data = array of feature values for all data 506 // SparseType = DMatrix.SparseType.CSR 507 //public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError 508 ArrayList<Float> dataList = new ArrayList<>(); 509 ArrayList<Long> headersList = new ArrayList<>(); 510 ArrayList<Integer> indicesList = new ArrayList<>(); 511 512 int numFeatures = 0; 513 long rowHeader = 0; 514 headersList.add(rowHeader); 515 for (SparseVector e : vectors) { 516 rowHeader = convertSingleExample(e,dataList,indicesList,headersList,rowHeader); 517 numFeatures = e.size(); // All vectors are assumed to be the same size. 518 } 519 520 float[] data = Util.toPrimitiveFloat(dataList); 521 int[] indices = Util.toPrimitiveInt(indicesList); 522 long[] headers = Util.toPrimitiveLong(headersList); 523 524 return new DMatrix(headers, indices, data, DMatrix.SparseType.CSR, numFeatures); 525 } 526 527 /** 528 * Tuple of a DMatrix, the number of valid features in each example, and the examples themselves. 529 * <p> 530 * One day it'll be a record. 531 * @param <T> The output type. 532 */ 533 protected static class DMatrixTuple<T extends Output<T>> { 534 public final DMatrix data; 535 public final int[] numValidFeatures; 536 public final Example<T>[] examples; 537 538 public DMatrixTuple(DMatrix data, int[] numValidFeatures, Example<T>[] examples) { 539 this.data = data; 540 this.numValidFeatures = numValidFeatures; 541 this.examples = examples; 542 } 543 } 544 545 /** 546 * Provenance for {@link XGBoostTrainer}. No longer used. 547 */ 548 @Deprecated 549 protected static class XGBoostTrainerProvenance extends SkeletalTrainerProvenance { 550 private static final long serialVersionUID = 1L; 551 552 protected <T extends Output<T>> XGBoostTrainerProvenance(XGBoostTrainer<T> host) { 553 super(host); 554 } 555 556 protected XGBoostTrainerProvenance(Map<String,Provenance> map) { 557 super(map); 558 } 559 } 560}