/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.xgboost;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.common.xgboost.XGBoostModel;
import org.tribuo.common.xgboost.XGBoostTrainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.time.OffsetDateTime;
import java.util.Collections;
import java.util.Map;
import java.util.function.Function;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A {@link Trainer} which wraps the XGBoost training procedure.
 * <p>
 * This only exposes a few of XGBoost's training parameters.
 * <p>
 * It uses pthreads outside of the JVM to parallelise the computation.
 * <p>
 * See:
 * <pre>
 * Chen T, Guestrin C.
 * "XGBoost: A Scalable Tree Boosting System"
 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Friedman JH.
 * "Greedy Function Approximation: a Gradient Boosting Machine"
 * Annals of Statistics, 2001.
 * </pre>
 * <p>
 * Note: XGBoost requires a native library. On macOS this library requires libomp (which can be installed via Homebrew);
 * on Windows the native library must be compiled into a jar, as it's not contained in the official XGBoost binary
 * on Maven Central.
 */
public final class XGBoostClassificationTrainer extends XGBoostTrainer<Label> {

    private static final Logger logger = Logger.getLogger(XGBoostClassificationTrainer.class.getName());

    @Config(description="Evaluation metric to use. The default value is set based on the objective function, so this can usually be left blank.")
    private String evalMetric = "";

    /**
     * Creates an XGBoostClassificationTrainer using the default parameter values
     * and the supplied number of trees.
     * @param numTrees Number of trees to boost.
     */
    public XGBoostClassificationTrainer(int numTrees) {
        this(numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, 4, true, Trainer.DEFAULT_SEED);
    }

    /**
     * Creates an XGBoostClassificationTrainer using the default parameter values,
     * the supplied number of trees, number of threads, and verbosity.
     * @param numTrees Number of trees to boost.
     * @param numThreads Number of threads to use.
     * @param silent Silence the training output text.
     */
    public XGBoostClassificationTrainer(int numTrees, int numThreads, boolean silent) {
        this(numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, numThreads, silent, Trainer.DEFAULT_SEED);
    }
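
    // A minimal usage sketch for this trainer. The Dataset<Label> variable
    // "trainData" is an assumption standing in for any Tribuo data source
    // which produces Label outputs:
    //
    //   XGBoostClassificationTrainer trainer = new XGBoostClassificationTrainer(50);
    //   XGBoostModel<Label> model = trainer.train(trainData, Collections.emptyMap());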

    /**
     * Create an XGBoost trainer.
     *
     * @param numTrees Number of trees to boost.
     * @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
     * @param gamma Minimum loss reduction to make a split (default 0, range [0,inf]).
     * @param maxDepth Maximum tree depth (default 6, range [1,inf]).
     * @param minChildWeight Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
     * @param subsample Subsample size for each tree (default 1, range (0,1]).
     * @param featureSubsample Subsample features for each tree (default 1, range (0,1]).
     * @param lambda L2 regularization term on weights (default 1).
     * @param alpha L1 regularization term on weights (default 0).
     * @param nThread Number of threads to use (default 4).
     * @param silent Silence the training output text.
     * @param seed RNG seed.
     */
    public XGBoostClassificationTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed) {
        super(numTrees,eta,gamma,maxDepth,minChildWeight,subsample,featureSubsample,lambda,alpha,nThread,silent,seed);
        postConfig();
    }

    /**
     * This gives direct access to the XGBoost parameter map.
     * <p>
     * It lets you set parameters which are not otherwise exposed, such as dropout trees or binary classification objectives.
     * <p>
     * This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results.
     * @param numTrees Number of trees to boost.
     * @param parameters A map from string to object, where each object value can be a Number or a String.
     */
    public XGBoostClassificationTrainer(int numTrees, Map<String,Object> parameters) {
        super(numTrees,parameters);
        postConfig();
    }
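
    // A sketch of the parameter-map constructor above, assuming the standard
    // XGBoost parameter names (these are passed through to the native library
    // without Tribuo's validation):
    //
    //   Map<String,Object> params = new HashMap<>();
    //   params.put("booster", "dart");   // dropout trees, as mentioned in the javadoc
    //   params.put("eta", 0.1);
    //   params.put("max_depth", 8);
    //   XGBoostClassificationTrainer trainer = new XGBoostClassificationTrainer(100, params);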

    /**
     * For OLCUT.
     */
    protected XGBoostClassificationTrainer() { }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        parameters.put("objective", "multi:softprob");
        if (!evalMetric.isEmpty()) {
            parameters.put("eval_metric", evalMetric);
        }
    }

    @Override
    public synchronized XGBoostModel<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<Label> outputInfo = examples.getOutputIDInfo();
        TrainerProvenance trainerProvenance = getProvenance();
        trainInvocationCounter++;
        // The multi:softprob objective requires the number of classes to be set.
        parameters.put("num_class", outputInfo.size());
        Booster model;
        // Maps each Label to its output id, which XGBoost uses as the regression target.
        Function<Label,Float> responseExtractor = (Label l) -> (float) outputInfo.getID(l);
        try {
            // Converts the Tribuo examples into a DMatrix suitable for the native XGBoost library.
            DMatrixTuple<Label> trainingData = convertExamples(examples, featureMap, responseExtractor);
            model = XGBoost.train(trainingData.data, parameters, numTrees, Collections.emptyMap(), null, null);
        } catch (XGBoostError e) {
            logger.log(Level.SEVERE, "XGBoost threw an error", e);
            throw new IllegalStateException(e);
        }

        ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        XGBoostModel<Label> xgModel = createModel("xgboost-classification-model", provenance, featureMap, outputInfo, Collections.singletonList(model), new XGBoostClassificationConverter());

        return xgModel;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}
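
// A sketch of configuring this trainer through an OLCUT XML configuration file,
// since evalMetric is an @Config field. The component name "xgb-trainer" is an
// assumption, and "numTrees" is assumed to be configurable via the XGBoostTrainer
// superclass:
//
//   <component name="xgb-trainer" type="org.tribuo.classification.xgboost.XGBoostClassificationTrainer">
//     <property name="numTrees" value="100"/>
//     <property name="evalMetric" value="mlogloss"/>
//   </component>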