Class XGBoostTrainer<T extends Output<T>>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>, WeightedExamples
- Direct Known Subclasses:
XGBoostClassificationTrainer, XGBoostRegressionTrainer
A Trainer which wraps the XGBoost training procedure.
This only exposes a few of XGBoost's training parameters.
It uses pthreads outside of the JVM to parallelise the computation.
See:
Chen T, Guestrin C. "XGBoost: A Scalable Tree Boosting System" Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
and for the original algorithm:
Friedman JH. "Greedy Function Approximation: a Gradient Boosting Machine" Annals of Statistics, 2001.
N.B.: XGBoost4J wraps the native C++ implementation of xgboost, which links to various C libraries, including libgomp and glibc (on Linux). If you're running on Alpine, which does not natively use glibc, you'll need to install glibc into the container. The macOS binary on Maven Central is compiled without OpenMP support, meaning that XGBoost is single threaded on macOS. If necessary, you can recompile the macOS binary with OpenMP support after installing libomp from Homebrew.
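To make the cited boosting idea concrete, here is a minimal sketch of first-order gradient boosting for squared loss in the spirit of Friedman (2001). It is neither XGBoost nor Tribuo code: it fits one-split regression stumps to the current residuals and shrinks each by a step size `eta`, which plays the same role as the `eta` hyperparameter documented below. The class `StumpBoosting` and its methods are hypothetical. The sketch deliberately omits XGBoost's second-order gradients, its `gamma`/`lambda`/`alpha` regularisation, and its threading.

```java
import java.util.Arrays;

// Hypothetical illustration class, not part of Tribuo or XGBoost4J.
final class StumpBoosting {
    // A depth-1 regression tree: one threshold on x, one constant per side.
    record Stump(double threshold, double left, double right) {
        double predict(double x) {
            return x <= threshold ? left : right;
        }
    }

    final double base;     // initial constant prediction: the mean of y
    final double eta;      // shrinkage applied to every stump's output
    final Stump[] stumps;  // one stump per boosting round

    private StumpBoosting(double base, double eta, Stump[] stumps) {
        this.base = base;
        this.eta = eta;
        this.stumps = stumps;
    }

    static StumpBoosting fit(double[] x, double[] y, int numTrees, double eta) {
        int n = x.length;
        double base = Arrays.stream(y).average().orElse(0.0);
        double[] pred = new double[n];
        Arrays.fill(pred, base);
        Stump[] stumps = new Stump[numTrees];
        double[] residual = new double[n];
        for (int t = 0; t < numTrees; t++) {
            // For squared loss (y - f)^2 / 2 the negative gradient is the residual y - f.
            for (int i = 0; i < n; i++) {
                residual[i] = y[i] - pred[i];
            }
            Stump s = fitStump(x, residual);
            stumps[t] = s;
            // Add the new stump, shrunk by eta, to the running prediction.
            for (int i = 0; i < n; i++) {
                pred[i] += eta * s.predict(x[i]);
            }
        }
        return new StumpBoosting(base, eta, stumps);
    }

    // Tries each observed x as a threshold and keeps the split with the
    // lowest sum of squared errors against the targets r.
    static Stump fitStump(double[] x, double[] r) {
        Stump best = null;
        double bestSse = Double.POSITIVE_INFINITY;
        for (double thr : x) {
            double sumL = 0, sumR = 0;
            int nL = 0, nR = 0;
            for (int i = 0; i < x.length; i++) {
                if (x[i] <= thr) { sumL += r[i]; nL++; } else { sumR += r[i]; nR++; }
            }
            double meanL = nL > 0 ? sumL / nL : 0.0;
            double meanR = nR > 0 ? sumR / nR : 0.0;
            double sse = 0;
            for (int i = 0; i < x.length; i++) {
                double d = r[i] - (x[i] <= thr ? meanL : meanR);
                sse += d * d;
            }
            if (sse < bestSse) {
                bestSse = sse;
                best = new Stump(thr, meanL, meanR);
            }
        }
        return best;
    }

    double predict(double xi) {
        double out = base;
        for (Stump s : stumps) {
            out += eta * s.predict(xi);
        }
        return out;
    }
}
```

With a small `eta`, each round only partly corrects the residuals. More trees are then needed, which is the usual trade-off between `eta` and the number of trees.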
-
Nested Class Summary
static enum XGBoostTrainer.BoosterType
    The type of XGBoost model.
protected static class XGBoostTrainer.DMatrixTuple<T extends Output<T>>
    Tuple of a DMatrix, the number of valid features in each example, and the examples themselves.
static enum XGBoostTrainer.LoggingVerbosity
    The logging verbosity of the native library.
static enum XGBoostTrainer.TreeMethod
    The tree building algorithm.
protected static class
    Deprecated. Unused.
-
Field Summary
protected int numTrees
    The number of trees to build.
protected Map<String,String> overrideParameters
    Override for the parameter map, must contain all parameters, including the objective function.
parameters
    The XGBoost parameter map, only accessed internally.
protected int trainInvocationCounter
    Number of times the train method has been called on this object.
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
-
Constructor Summary
protected XGBoostTrainer()
    For olcut.
protected XGBoostTrainer(int numTrees)
    Constructs an XGBoost trainer using the specified number of trees.
protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed)
    Create an XGBoost trainer.
protected XGBoostTrainer(int numTrees, int numThreads, boolean silent)
    Constructs an XGBoost trainer using the specified number of trees.
protected XGBoostTrainer(int numTrees, Map<String, Object> parameters)
    This gives direct access to the XGBoost parameter map.
protected XGBoostTrainer(XGBoostTrainer.BoosterType boosterType, XGBoostTrainer.TreeMethod treeMethod, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, XGBoostTrainer.LoggingVerbosity verbosity, long seed)
    Create an XGBoost trainer.
-
Method Summary
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples)
    Converts a dataset into a DMatrix.
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples, Function<T, Float> responseExtractor)
    Converts a dataset into a DMatrix.
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap)
    Converts an example into a DMatrix.
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T, Float> responseExtractor)
    Converts an example into a DMatrix.
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap)
    Converts an iterable of examples into a DMatrix.
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T, Float> responseExtractor)
    Converts an iterable of examples into a DMatrix.
protected static <T extends Output<T>> long convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header)
    Writes out the features from an example into the three supplied ArrayLists.
protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVector(SparseVector vector)
    Used when predicting with an externally trained XGBoost model.
protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVectors(List<SparseVector> vectors)
    Used when predicting with an externally trained XGBoost model.
copyParams(Map<String, ?> input)
    Returns a copy of the supplied parameter map which has the appropriate type for passing to XGBoost.train.
protected XGBoostModel<T> createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<ml.dmlc.xgboost4j.java.Booster> models, XGBoostOutputConverter<T> converter)
    Creates an XGBoost model from the booster list.
int getInvocationCount()
    The number of times this trainer instance has had its train method invoked.
void postConfig()
    Used by the OLCUT configuration system, and should not be called by external code.
void setInvocationCount(int invocationCount)
    Set the internal state of the trainer to the provided number of invocations of the train method.
toString()
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.provenance.Provenancable
getProvenance
-
Field Details
-
parameters
The XGBoost parameter map, only accessed internally. -
overrideParameters
@Config(description="Override for parameters, if used must contain all the relevant parameters, including the objective") protected Map<String,String> overrideParameters
Override for the parameter map, must contain all parameters, including the objective function. -
numTrees
@Config(mandatory=true, description="The number of trees to build.") protected int numTrees
The number of trees to build. -
trainInvocationCounter
protected int trainInvocationCounter
Number of times the train method has been called on this object.
-
-
Constructor Details
-
XGBoostTrainer
protected XGBoostTrainer(int numTrees)
Constructs an XGBoost trainer using the specified number of trees.
- Parameters:
numTrees
- The number of trees.
-
XGBoostTrainer
protected XGBoostTrainer(int numTrees, int numThreads, boolean silent)
Constructs an XGBoost trainer using the specified number of trees.
- Parameters:
numTrees
- The number of trees.
numThreads
- The number of training threads.
silent
- Should the logging be silenced?
-
XGBoostTrainer
protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed)
Create an XGBoost trainer.
Sets the boosting algorithm to XGBoostTrainer.BoosterType.GBTREE and the tree building algorithm to XGBoostTrainer.TreeMethod.AUTO.
- Parameters:
numTrees
- Number of trees to boost.
eta
- Step size shrinkage parameter (default 0.3, range [0,1]).
gamma
- Minimum loss reduction required to make a split (default 0, range [0,inf]).
maxDepth
- Maximum tree depth (default 6, range [1,inf]).
minChildWeight
- Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
subsample
- Fraction of the training examples sampled for each tree (default 1, range (0,1]).
featureSubsample
- Fraction of the features sampled for each tree (default 1, range (0,1]).
lambda
- L2 regularization term on weights (default 1).
alpha
- L1 regularization term on weights (default 0).
nThread
- Number of threads to use (default 4).
silent
- Silence the training output text.
seed
- RNG seed.
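The documented ranges can be checked before any native call is made. The following sketch does that with a hypothetical helper class `XGBoostParamCheck`. It is not Tribuo's own validation code, and the lower bounds of 1 on `numTrees` and `nThread` are assumptions not stated above. Each check is written as `!(in range)` so that `NaN` is also rejected.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not Tribuo's validation logic. Ranges follow the
// parameter documentation; numTrees >= 1 and nThread >= 1 are assumptions.
final class XGBoostParamCheck {
    static List<String> validate(int numTrees, double eta, double gamma, int maxDepth,
                                 double minChildWeight, double subsample,
                                 double featureSubsample, int nThread) {
        List<String> errors = new ArrayList<>();
        if (numTrees < 1) errors.add("numTrees must be at least 1");
        // Written as !(in range) so that NaN fails the check as well.
        if (!(eta >= 0 && eta <= 1)) errors.add("eta must be in [0,1]");
        if (!(gamma >= 0)) errors.add("gamma must be in [0,inf]");
        if (maxDepth < 1) errors.add("maxDepth must be in [1,inf]");
        if (!(minChildWeight >= 0)) errors.add("minChildWeight must be in [0,inf]");
        if (!(subsample > 0 && subsample <= 1)) errors.add("subsample must be in (0,1]");
        if (!(featureSubsample > 0 && featureSubsample <= 1)) {
            errors.add("featureSubsample must be in (0,1]");
        }
        if (nThread < 1) errors.add("nThread must be at least 1");
        return errors;
    }
}
```

The helper collects every violation instead of stopping at the first one, so a caller can report all bad settings at once.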
-
XGBoostTrainer
protected XGBoostTrainer(XGBoostTrainer.BoosterType boosterType, XGBoostTrainer.TreeMethod treeMethod, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, XGBoostTrainer.LoggingVerbosity verbosity, long seed)
Create an XGBoost trainer.
- Parameters:
boosterType
- The base learning algorithm.
treeMethod
- The tree building algorithm if using a tree booster.
numTrees
- Number of trees to boost.
eta
- Step size shrinkage parameter (default 0.3, range [0,1]).
gamma
- Minimum loss reduction required to make a split (default 0, range [0,inf]).
maxDepth
- Maximum tree depth (default 6, range [1,inf]).
minChildWeight
- Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
subsample
- Fraction of the training examples sampled for each tree (default 1, range (0,1]).
featureSubsample
- Fraction of the features sampled for each tree (default 1, range (0,1]).
lambda
- L2 regularization term on weights (default 1).
alpha
- L1 regularization term on weights (default 0).
nThread
- Number of threads to use (default 4).
verbosity
- Set the logging verbosity of the native library.
seed
- RNG seed.
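A trainer like this ultimately has to hand XGBoost a string-keyed parameter map. The sketch below shows one plausible translation from the constructor arguments to XGBoost's documented parameter names: `booster`, `tree_method`, `eta`, `gamma`, `max_depth`, `min_child_weight`, `subsample`, `colsample_bytree`, `lambda`, `alpha`, `nthread`, `verbosity` and `seed`.

Several parts of the sketch are assumptions rather than details taken from Tribuo's source:
- The enums are local stand-ins, and their constant names are assumed.
- `featureSubsample` is mapped to `colsample_bytree`.
- The verbosity levels are mapped to 0-3.

`numTrees` is left out of the map because XGBoost receives the number of boosting rounds as a separate argument to training.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

// Hypothetical sketch; the enums stand in for XGBoostTrainer.BoosterType,
// TreeMethod and LoggingVerbosity, and their constant names are assumed.
final class XGBoostParamMapSketch {
    enum Booster { GBTREE, GBLINEAR, DART }
    enum Tree { AUTO, EXACT, APPROX, HIST }
    enum Verbosity { SILENT, WARNING, INFO, DEBUG } // ordinal 0..3 = XGBoost verbosity levels

    static Map<String, Object> build(Booster booster, Tree treeMethod, double eta, double gamma,
                                     int maxDepth, double minChildWeight, double subsample,
                                     double featureSubsample, double lambda, double alpha,
                                     int nThread, Verbosity verbosity, long seed) {
        Map<String, Object> p = new HashMap<>();
        p.put("booster", booster.name().toLowerCase(Locale.ROOT));
        // tree_method only applies to tree boosters, so skip it for the linear booster.
        if (booster != Booster.GBLINEAR) {
            p.put("tree_method", treeMethod.name().toLowerCase(Locale.ROOT));
        }
        p.put("eta", eta);
        p.put("gamma", gamma);
        p.put("max_depth", maxDepth);
        p.put("min_child_weight", minChildWeight);
        p.put("subsample", subsample);
        p.put("colsample_bytree", featureSubsample);
        p.put("lambda", lambda);
        p.put("alpha", alpha);
        p.put("nthread", nThread);
        p.put("verbosity", verbosity.ordinal());
        p.put("seed", seed);
        return p;
    }
}
```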
-
XGBoostTrainer
protected XGBoostTrainer(int numTrees, Map<String, Object> parameters)
This gives direct access to the XGBoost parameter map. It lets you pick options that we haven't exposed, such as dropout trees or binary classification.
This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results.
- Parameters:
numTrees
- Number of trees to boost.
parameters
- A map from string to object, where each object is a Number or a String.
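As an illustration of the kind of map this constructor accepts, the sketch below builds parameters for DART (dropout trees) with a binary logistic objective. It uses XGBoost's documented parameter names `booster`, `rate_drop` and `objective`, and rejects any value that is neither a `Number` nor a `String`. The class `RawParamsSketch` is hypothetical. It does not check whether a given combination makes sense for a particular Tribuo subclass, which is exactly the risk described above.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper; the parameter names are XGBoost's, the class is illustrative only.
final class RawParamsSketch {
    // Enforces the documented contract that every value is a Number or a String.
    static Map<String, Object> checked(Map<String, Object> params) {
        for (Map.Entry<String, Object> e : params.entrySet()) {
            Object v = e.getValue();
            if (!(v instanceof Number) && !(v instanceof String)) {
                String type = v == null ? "null" : v.getClass().getName();
                throw new IllegalArgumentException(
                        "Parameter '" + e.getKey() + "' has unsupported type " + type);
            }
        }
        return params;
    }

    // DART ("dropout trees") booster with a binary logistic objective,
    // expressed directly as XGBoost parameters.
    static Map<String, Object> dartBinaryExample() {
        Map<String, Object> p = new TreeMap<>();
        p.put("booster", "dart");
        p.put("rate_drop", 0.1); // fraction of existing trees dropped each round
        p.put("objective", "binary:logistic");
        p.put("eta", 0.3);
        p.put("max_depth", 6);
        return checked(p);
    }
}
```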
-
XGBoostTrainer
protected XGBoostTrainer()
For olcut.
-
-
Method Details
-
postConfig
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig
in interface com.oracle.labs.mlrg.olcut.config.Configurable
-
toString
-
createModel
protected XGBoostModel<T> createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<ml.dmlc.xgboost4j.java.Booster> models, XGBoostOutputConverter<T> converter)
Creates an XGBoost model from the booster list.
- Parameters:
name
- The model name.
provenance
- The model provenance.
featureIDMap
- The feature domain.
outputIDInfo
- The output domain.
models
- The boosters.
converter
- The converter from XGBoost's output to Tribuo predictions.
- Returns:
- An XGBoost model.
-
copyParams
Returns a copy of the supplied parameter map which has the appropriate type for passing to XGBoost.train.
- Parameters:
input
- The parameter map.
- Returns:
- A (shallow) copy of the supplied map.
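The copy exists because of Java's generic invariance: a `Map<String,Integer>` is not a `Map<String,Object>`, so it cannot be passed directly where the latter is expected (assuming, as the description implies, that `XGBoost.train` takes a `Map<String, Object>`). Below is a minimal sketch of a shallow copy with that shape. The class name is hypothetical, and the `Map<String, Object>` return type is assumed because this extract does not show it.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for copyParams; the Map<String,Object> return type is assumed.
final class CopyParamsSketch {
    // Shallow copy: a new map, but the value objects themselves are shared.
    // Copying (rather than casting) is needed because Map<String,Integer>
    // is not a subtype of Map<String,Object> in Java's generics.
    static Map<String, Object> copyParams(Map<String, ?> input) {
        return new HashMap<>(input);
    }
}
```

Adding entries to the copy leaves the caller's map unchanged.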
-
getInvocationCount
public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount
in interface Trainer<T extends Output<T>>
- Returns:
- The number of train invocations.
-
setInvocationCount
public void setInvocationCount(int invocationCount)
Description copied from interface: Trainer
Set the internal state of the trainer to the provided number of invocations of the train method.
This is used when reproducing a Tribuo-trained model: it restores the RNG to the state it was in when Tribuo trained the original model, by simulating invocations of the train method. This method should ALWAYS be overridden; the default method exists purely for compatibility.
In a future major release this default implementation will be removed.
- Specified by:
setInvocationCount
in interface Trainer<T extends Output<T>>
- Parameters:
invocationCount
- the number of invocations of the train method to simulate
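The replay idea can be sketched with `java.util.SplittableRandom`. Each simulated `train` call consumes one `split()` from the trainer's RNG. Resetting to the seed and replaying n splits therefore reproduces the RNG state after n invocations. `ReplayableTrainerSketch` is a hypothetical class: it illustrates the `Trainer` contract described above, not how XGBoostTrainer itself uses its seed.

```java
import java.util.SplittableRandom;

// Hypothetical trainer skeleton illustrating the invocation-count contract;
// it is not XGBoostTrainer's implementation.
final class ReplayableTrainerSketch {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    ReplayableTrainerSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    // Stand-in for train(): each call consumes one split of the trainer RNG
    // and returns a value drawn from that call's private RNG.
    synchronized long train() {
        SplittableRandom localRng = rng.split();
        invocationCount++;
        return localRng.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    // Resets to the seed, then simulates newCount train() calls by replaying the splits.
    synchronized void setInvocationCount(int newCount) {
        if (newCount < 0) {
            throw new IllegalArgumentException("invocation count must be non-negative, found " + newCount);
        }
        rng = new SplittableRandom(seed);
        for (invocationCount = 0; invocationCount < newCount; invocationCount++) {
            rng.split();
        }
    }
}
```

Resetting and replaying, rather than only skipping forward, also lets the count be moved backwards.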
-
convertDataset
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples, Function<T, Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
Converts a dataset into a DMatrix.
- Type Parameters:
T
- The type of the output.
- Parameters:
examples
- The examples to convert.
responseExtractor
- The extraction function for the output.
- Returns:
- A DMatrixTuple.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library failed to construct the DMatrix.
-
convertDataset
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples) throws ml.dmlc.xgboost4j.java.XGBoostError
Converts a dataset into a DMatrix.
- Type Parameters:
T
- The type of the output.
- Parameters:
examples
- The examples to convert.
- Returns:
- A DMatrixTuple.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library failed to construct the DMatrix.
-
convertExamples
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap) throws ml.dmlc.xgboost4j.java.XGBoostError
Converts an iterable of examples into a DMatrix.
- Type Parameters:
T
- The type of the output.
- Parameters:
examples
- The examples to convert.
featureMap
- The feature id map which supplies the indices.
- Returns:
- A DMatrixTuple.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library failed to construct the DMatrix.
-
convertExamples
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T, Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
Converts an iterable of examples into a DMatrix.
- Type Parameters:
T
- The type of the output.
- Parameters:
examples
- The examples to convert.
featureMap
- The feature id map which supplies the indices.
responseExtractor
- The extraction function for the output.
- Returns:
- A DMatrixTuple.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library failed to construct the DMatrix.
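The response extractor turns each example's output into the single float XGBoost uses as a label. For classification that is presumably a class index, and for regression a target value. The sketch below shows that step together with a per-example count of how many features the feature map knows, which is the kind of count a `DMatrixTuple` is described as carrying.

The `Ex` record and the plain `Map<String,Integer>` standing in for `ImmutableFeatureMap` are hypothetical simplifications.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical simplification: Ex stands in for Example<T> and a Map<String,Integer>
// stands in for ImmutableFeatureMap.
final class ResponseExtractionSketch {
    record Ex<T>(Map<String, Double> features, T output) {}

    record Converted(float[] labels, int[] numValidFeatures) {}

    static <T> Converted convert(Iterable<Ex<T>> examples, Map<String, Integer> featureIds,
                                 Function<T, Float> responseExtractor) {
        List<Float> labels = new ArrayList<>();
        List<Integer> counts = new ArrayList<>();
        for (Ex<T> ex : examples) {
            int valid = 0;
            for (String name : ex.features().keySet()) {
                if (featureIds.containsKey(name)) {
                    valid++; // features unknown to the map are not counted
                }
            }
            labels.add(responseExtractor.apply(ex.output()));
            counts.add(valid);
        }
        float[] labelArray = new float[labels.size()];
        int[] countArray = new int[counts.size()];
        for (int i = 0; i < labelArray.length; i++) {
            labelArray[i] = labels.get(i);
            countArray[i] = counts.get(i);
        }
        return new Converted(labelArray, countArray);
    }
}
```

Passing the extractor in as a `Function` keeps one conversion routine usable for both classification and regression outputs.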
-
convertExample
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap) throws ml.dmlc.xgboost4j.java.XGBoostError
Converts an example into a DMatrix.
- Type Parameters:
T
- The type of the output.
- Parameters:
example
- The example to convert.
featureMap
- The feature id map which supplies the indices.
- Returns:
- A DMatrixTuple.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library failed to construct the DMatrix.
-
convertExample
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T, Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
Converts an example into a DMatrix.
- Type Parameters:
T
- The type of the output.
- Parameters:
example
- The example to convert.
featureMap
- The feature id map which supplies the indices.
responseExtractor
- The extraction function for the output.
- Returns:
- A DMatrixTuple.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library failed to construct the DMatrix.
-
convertSingleExample
protected static <T extends Output<T>> long convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header)
Writes out the features from an example into the three supplied ArrayLists.
This is used to transform examples into the right format for an XGBoost call, and is used by both the classification and regression XGBoost backends. The ArrayLists must be non-null and may already contain values, as this method is called repeatedly to build up lists containing all the feature values for a dataset.
Features with colliding feature ids are summed together.
Can throw IllegalArgumentException if the Example contains no features.
- Type Parameters:
T
- The type of the example.
- Parameters:
example
- The example to inspect.
featureMap
- The feature map of the model/dataset (used to preserve hash information).
dataList
- The output feature values.
indicesList
- The output indices.
headersList
- The output header positions (an integer saying how long each sparse example is).
header
- The current header position.
- Returns:
- The updated header position.
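A simplified sketch of the behaviour described above:
1. Look up each feature's id.
2. Sum values whose ids collide, as can happen with hashed feature names.
3. Append the row's values and ids in increasing id order.
4. Advance the header by the number of stored entries.

`CsrAppendSketch`, its `Feature` record and the plain `Map<String,Integer>` feature map are hypothetical. Two further conventions are assumptions: features missing from the map are skipped, and the caller seeds `headersList` with 0.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical sketch: Feature and Map<String,Integer> stand in for Tribuo's
// Feature and ImmutableFeatureMap. Not the library's implementation.
final class CsrAppendSketch {
    record Feature(String name, double value) {}

    static long convertSingleExample(List<Feature> features, Map<String, Integer> featureMap,
                                     ArrayList<Float> dataList, ArrayList<Integer> indicesList,
                                     ArrayList<Long> headersList, long header) {
        if (features.isEmpty()) {
            throw new IllegalArgumentException("Example contains no features");
        }
        // TreeMap keeps ids sorted; merge sums the values of colliding ids.
        TreeMap<Integer, Float> row = new TreeMap<>();
        for (Feature f : features) {
            Integer id = featureMap.get(f.name());
            if (id == null) {
                continue; // assumption: unknown features are skipped
            }
            row.merge(id, (float) f.value(), Float::sum);
        }
        for (Map.Entry<Integer, Float> e : row.entrySet()) {
            indicesList.add(e.getKey());
            dataList.add(e.getValue());
        }
        // The header advances by the number of entries stored for this row.
        header += row.size();
        headersList.add(header);
        return header;
    }
}
```

Calling this once per example, each time passing the previous return value as `header`, builds up the compressed sparse row arrays for a whole dataset.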
-
convertSparseVector
protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVector(SparseVector vector) throws ml.dmlc.xgboost4j.java.XGBoostError
Used when predicting with an externally trained XGBoost model.
- Parameters:
vector
- The features to convert.
- Returns:
- A DMatrix representing the features.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library returns an error state.
-
convertSparseVectors
protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVectors(List<SparseVector> vectors) throws ml.dmlc.xgboost4j.java.XGBoostError
Used when predicting with an externally trained XGBoost model.
It is assumed all vectors are the same size when passed into this function.
- Parameters:
vectors
- The batch of features to convert.
- Returns:
- A DMatrix representing the batch of features.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library returns an error state.
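A batch of equally sized sparse vectors maps naturally onto the compressed sparse row (CSR) layout: one row-offset array, one column-index array and one value array. The sketch below builds those arrays. Unlike the method documented above, which only assumes the vectors are the same size, it checks that assumption explicitly. `SparseBatchSketch` and its records are hypothetical, and the hand-off to an XGBoost4J `DMatrix` constructor is deliberately omitted.

```java
import java.util.List;

// Hypothetical types: SparseVec stands in for Tribuo's SparseVector, Csr for the
// arrays a sparse DMatrix would be built from. No XGBoost4J call is made here.
final class SparseBatchSketch {
    record SparseVec(int size, int[] indices, double[] values) {}

    record Csr(long[] rowHeaders, int[] indices, float[] data, int numColumns) {}

    static Csr toCsr(List<SparseVec> vectors) {
        if (vectors.isEmpty()) {
            throw new IllegalArgumentException("Cannot convert an empty batch");
        }
        int numColumns = vectors.get(0).size();
        int nnz = 0;
        for (SparseVec v : vectors) {
            // Check explicitly what the documented method only assumes.
            if (v.size() != numColumns) {
                throw new IllegalArgumentException(
                        "Expected vectors of size " + numColumns + ", found " + v.size());
            }
            if (v.indices().length != v.values().length) {
                throw new IllegalArgumentException("indices and values differ in length");
            }
            nnz += v.indices().length;
        }
        long[] rowHeaders = new long[vectors.size() + 1]; // rowHeaders[0] stays 0
        int[] indices = new int[nnz];
        float[] data = new float[nnz];
        int pos = 0;
        for (int r = 0; r < vectors.size(); r++) {
            SparseVec v = vectors.get(r);
            for (int k = 0; k < v.indices().length; k++) {
                indices[pos] = v.indices()[k];
                data[pos] = (float) v.values()[k]; // XGBoost stores feature values as floats
                pos++;
            }
            rowHeaders[r + 1] = pos; // end offset of row r
        }
        return new Csr(rowHeaders, indices, data, numColumns);
    }
}
```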
-