Class XGBoostTrainer<T extends Output<T>>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>, WeightedExamples
- Direct Known Subclasses:
XGBoostClassificationTrainer, XGBoostRegressionTrainer
A Trainer which wraps the XGBoost training procedure.
This only exposes a few of XGBoost's training parameters.
It uses pthreads outside of the JVM to parallelise the computation.
See:
Chen T, Guestrin C. "XGBoost: A Scalable Tree Boosting System" Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
and for the original algorithm:
Friedman JH. "Greedy Function Approximation: a Gradient Boosting Machine" Annals of Statistics, 2001.
N.B.: This uses a native C implementation of XGBoost that links to various C libraries, including libgomp and glibc. If you're running on Alpine, which does not natively use glibc, you'll need to install glibc into the container. On Windows, this binary is not available in the Maven Central release, so you'll need to compile it from source.
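For orientation, the regularized objective from Chen & Guestrin (2016) shows where several of the constructor's hyperparameters enter. This summarizes the paper's formulation rather than Tribuo's code, and the L1 term weighted by α is an option of the XGBoost library that the paper's Ω omits. Here T is the number of leaves in a tree, w_j are the leaf weights, and G_L, H_L (G_R, H_R) are the sums of first- and second-order loss gradients over the examples that fall in the left (right) child of a candidate split:

```latex
\hat{y}_i^{(t)} = \hat{y}_i^{(t-1)} + \eta \, f_t(\mathbf{x}_i)

\mathcal{L}^{(t)} = \sum_{i=1}^{n} l\!\left(y_i,\ \hat{y}_i^{(t-1)} + f_t(\mathbf{x}_i)\right) + \Omega(f_t),
\qquad
\Omega(f) = \gamma T + \tfrac{1}{2}\lambda \sum_{j=1}^{T} w_j^2 + \alpha \sum_{j=1}^{T} \lvert w_j \rvert

\text{Gain}_{\alpha = 0} = \tfrac{1}{2}\left[\frac{G_L^2}{H_L + \lambda} + \frac{G_R^2}{H_R + \lambda}
  - \frac{(G_L + G_R)^2}{H_L + H_R + \lambda}\right] - \gamma
```

A split is only worthwhile when its gain is positive, which is why gamma acts as the minimum loss reduction required to make a split; eta (η) shrinks each new tree's contribution, and lambda (λ) and alpha (α) penalize large leaf weights.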
-
Nested Class Summary
Nested Classes (Modifier and Type, Class, Description):
- static enum: The type of XGBoost model.
- protected static class XGBoostTrainer.DMatrixTuple<T extends Output<T>>: Tuple of a DMatrix, the number of valid features in each example, and the examples themselves.
- protected static class: Deprecated.
-
Field Summary
Fields (Modifier and Type, Field, Description):
- protected int
- protected int
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED
-
Constructor Summary
Constructors (Modifier, Constructor, Description):
- protected XGBoostTrainer(): For olcut.
- protected XGBoostTrainer(int numTrees)
- protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed): Create an XGBoost trainer.
- protected XGBoostTrainer(int numTrees, int numThreads, boolean silent)
- protected XGBoostTrainer(int numTrees, Map<String, Object> parameters): This gives direct access to the XGBoost parameter map.
-
Method Summary
Methods (Modifier and Type, Method, Description):
- protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples)
- protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples, Function<T, Float> responseExtractor)
- protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap)
- protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T, Float> responseExtractor): Converts an example into a DMatrix.
- protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap)
- protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T, Float> responseExtractor): Converts an iterable of examples into a DMatrix.
- protected static <T extends Output<T>> long convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header): Writes out the features from an example into the three supplied ArrayLists.
- protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVector(SparseVector vector): Used when predicting with an externally trained XGBoost model.
- protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVectors(List<SparseVector> vectors): Used when predicting with an externally trained XGBoost model.
- protected XGBoostModel<T> createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<ml.dmlc.xgboost4j.java.Booster> models, XGBoostOutputConverter<T> converter)
- int getInvocationCount(): The number of times this trainer instance has had its train method invoked.
- void postConfig(): Used by the OLCUT configuration system, and should not be called by external code.
- String toString()
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.provenance.Provenancable
getProvenance
-
Field Details
-
parameters
-
numTrees
-
trainInvocationCounter
-
-
Constructor Details
-
XGBoostTrainer
-
XGBoostTrainer
-
XGBoostTrainer
protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed)
Create an XGBoost trainer.
- Parameters:
numTrees
- Number of trees to boost.
eta
- Step size shrinkage parameter (default 0.3, range [0,1]).
gamma
- Minimum loss reduction required to make a split (default 0, range [0,inf)).
maxDepth
- Maximum tree depth (default 6, range [1,inf)).
minChildWeight
- Minimum sum of instance weights needed in a leaf (default 1, range [0,inf)).
subsample
- Fraction of the training examples sampled for each tree (default 1, range (0,1]).
featureSubsample
- Fraction of the features sampled for each tree (default 1, range (0,1]).
lambda
- L2 regularization term on weights (default 1).
alpha
- L1 regularization term on weights (default 0).
nThread
- Number of threads to use (default 4).
silent
- Silence the training output text.
seed
- RNG seed.
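As a sketch of what these arguments amount to, the hypothetical helper below (XGBoostParamSketch, not part of Tribuo) checks the ranges documented above and places each value into a map keyed by standard XGBoost parameter names. The key mapping, for instance featureSubsample to colsample_bytree and silent to an integer silent flag, is an assumption rather than a transcription of Tribuo's source; newer XGBoost releases replace silent with verbosity.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical helper, not part of Tribuo: maps constructor-style arguments onto XGBoost parameter names. */
final class XGBoostParamSketch {
    private XGBoostParamSketch() {}

    private static void check(boolean ok, String message) {
        if (!ok) {
            throw new IllegalArgumentException(message);
        }
    }

    static Map<String, Object> buildParams(double eta, double gamma, int maxDepth, double minChildWeight,
                                           double subsample, double featureSubsample, double lambda,
                                           double alpha, int nThread, boolean silent, long seed) {
        // Range checks follow the ranges listed in the constructor documentation.
        check(eta >= 0 && eta <= 1, "eta must be in [0,1], found " + eta);
        check(gamma >= 0, "gamma must be non-negative, found " + gamma);
        check(maxDepth >= 1, "maxDepth must be at least 1, found " + maxDepth);
        check(minChildWeight >= 0, "minChildWeight must be non-negative, found " + minChildWeight);
        check(subsample > 0 && subsample <= 1, "subsample must be in (0,1], found " + subsample);
        check(featureSubsample > 0 && featureSubsample <= 1,
                "featureSubsample must be in (0,1], found " + featureSubsample);

        // Keys are standard XGBoost parameter names; this exact mapping is an assumption.
        Map<String, Object> params = new HashMap<>();
        params.put("eta", eta);
        params.put("gamma", gamma);
        params.put("max_depth", maxDepth);
        params.put("min_child_weight", minChildWeight);
        params.put("subsample", subsample);
        params.put("colsample_bytree", featureSubsample);
        params.put("lambda", lambda);
        params.put("alpha", alpha);
        params.put("nthread", nThread);
        params.put("silent", silent ? 1 : 0);
        params.put("seed", seed);
        return params;
    }

    public static void main(String[] args) {
        // The documented defaults, with an arbitrary seed.
        Map<String, Object> params = buildParams(0.3, 0, 6, 1, 1, 1, 1, 0, 4, true, 12345L);
        System.out.println(params);
    }
}
```

Validating before building the map is the point of the typed constructor: an out-of-range eta fails fast in Java instead of being passed to the native library.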
-
XGBoostTrainer
This gives direct access to the XGBoost parameter map.
It lets you pick options that we haven't exposed, like dropout trees (the DART booster), binary classification, etc.
This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results.
- Parameters:
numTrees
- Number of trees to boost.
parameters
- A map from string to object, where each object can be a Number or a String.
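A minimal sketch of assembling such a map, assuming the DART booster settings (booster, rate_drop, skip_drop) and objective=binary:logistic from the XGBoost parameter documentation. RawParamSketch and checkTypes are hypothetical names; the check enforces only the documented Number-or-String rule and does no hyperparameter validation, in keeping with the warning above. The map would then be handed to a concrete subclass, assuming it exposes a matching (numTrees, parameters) constructor.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical sketch: building a raw XGBoost parameter map and checking the value types. */
final class RawParamSketch {
    private RawParamSketch() {}

    /** Rejects any value that is not a Number or a String; performs no range validation. */
    static Map<String, Object> checkTypes(Map<String, Object> params) {
        for (Map.Entry<String, Object> e : params.entrySet()) {
            Object v = e.getValue();
            if (!(v instanceof Number) && !(v instanceof String)) {
                throw new IllegalArgumentException("Parameter '" + e.getKey()
                        + "' must be a Number or a String, found "
                        + (v == null ? "null" : v.getClass().getName()));
            }
        }
        return params;
    }

    public static void main(String[] args) {
        Map<String, Object> params = new LinkedHashMap<>();
        // Dropout trees via the DART booster, using standard XGBoost parameter names.
        params.put("booster", "dart");
        params.put("rate_drop", 0.1);
        params.put("skip_drop", 0.5);
        // Binary classification objective.
        params.put("objective", "binary:logistic");
        params.put("eta", 0.3);
        params.put("max_depth", 6);
        System.out.println(checkTypes(params));
    }
}
```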
-
XGBoostTrainer
protected XGBoostTrainer()
For olcut.
-
-
Method Details
-
postConfig
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
-
toString
-
createModel
protected XGBoostModel<T> createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<ml.dmlc.xgboost4j.java.Booster> models, XGBoostOutputConverter<T> converter)
-
getInvocationCount
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
- The number of train invocations.
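The replicability idea can be sketched with java.util.SplittableRandom: each training call splits a child generator off a trainer-level generator and increments a counter, so resetting the generator and replaying that many splits reproduces the stream. InvocationCounterSketch, its train stand-in and its setInvocationCount method are hypothetical, meant to illustrate the mechanism described here rather than reproduce Tribuo's code.

```java
import java.util.SplittableRandom;

/**
 * Hypothetical sketch (not Tribuo's implementation): an invocation counter that makes
 * the per-call random stream replicable.
 */
final class InvocationCounterSketch {
    private final long seed;
    private SplittableRandom rng;
    private int trainInvocationCounter = 0;

    InvocationCounterSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stands in for train(): returns the seed that would be handed to the native booster. */
    synchronized long train() {
        SplittableRandom localRNG = rng.split(); // advances the trainer-level RNG
        trainInvocationCounter++;
        return localRNG.nextLong();
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    /** Resets the RNG so that the next train() call behaves like call number count + 1. */
    synchronized void setInvocationCount(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count must be non-negative, found " + count);
        }
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split(); // replay the splits made by earlier invocations
        }
        trainInvocationCounter = count;
    }

    public static void main(String[] args) {
        InvocationCounterSketch a = new InvocationCounterSketch(42L);
        a.train(); // first invocation
        long second = a.train();

        InvocationCounterSketch b = new InvocationCounterSketch(42L);
        b.setInvocationCount(1);
        System.out.println(b.train() == second);
    }
}
```

Because split() is deterministic given the generator's state, trainer b, fast-forwarded by one invocation, draws the same seed on its next call as trainer a did on its second, so main prints true.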
-
convertDataset
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples, Function<T, Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
-
convertDataset
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples) throws ml.dmlc.xgboost4j.java.XGBoostError
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
-
convertExamples
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap) throws ml.dmlc.xgboost4j.java.XGBoostError
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
-
convertExamples
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T, Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
Converts an iterable of examples into a DMatrix.
- Type Parameters:
T
- The type of the output.
- Parameters:
examples
- The examples to convert.
featureMap
- The feature id map which supplies the indices.
responseExtractor
- The extraction function for the output.
- Returns:
- A DMatrixTuple.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library failed to construct the DMatrix.
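The responseExtractor turns each output into the single float that XGBoost trains against; in xgboost4j, labels are attached to a DMatrix as a float[]. A minimal sketch of that extraction step, with a hypothetical Target class standing in for a Tribuo output type and a hypothetical extractLabels helper:

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/** Hypothetical sketch: applying a response extractor to build a float label array. */
final class ResponseExtractorSketch {
    private ResponseExtractorSketch() {}

    /** Hypothetical stand-in for an output type such as a regression target. */
    static final class Target {
        final double value;

        Target(double value) {
            this.value = value;
        }
    }

    /** Applies the extraction function to every output, in example order. */
    static <T> float[] extractLabels(List<T> outputs, Function<T, Float> responseExtractor) {
        float[] labels = new float[outputs.size()];
        for (int i = 0; i < labels.length; i++) {
            labels[i] = responseExtractor.apply(outputs.get(i));
        }
        return labels;
    }

    public static void main(String[] args) {
        List<Target> outputs = List.of(new Target(1.5), new Target(-2.0), new Target(0.25));
        // Narrow each double target to the float XGBoost expects.
        float[] labels = extractLabels(outputs, t -> (float) t.value);
        System.out.println(Arrays.toString(labels));
    }
}
```

Passing the extractor in as a function keeps the conversion code shared between output types, which matches the method being reused by both the classification and regression backends.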
-
convertExample
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap) throws ml.dmlc.xgboost4j.java.XGBoostError
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
-
convertExample
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T, Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
Converts an example into a DMatrix.
- Type Parameters:
T
- The type of the output.
- Parameters:
example
- The example to convert.
featureMap
- The feature id map which supplies the indices.
responseExtractor
- The extraction function for the output.
- Returns:
- A DMatrixTuple.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library failed to construct the DMatrix.
-
convertSingleExample
protected static <T extends Output<T>> long convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header)
Writes out the features from an example into the three supplied ArrayLists.
This is used to transform examples into the right format for an XGBoost call. It's used in both the Classification and Regression XGBoost backends. The ArrayLists must be non-null, and can contain existing values (as this method is called multiple times to build up an ArrayList containing all the feature values for a dataset).
Features with colliding feature ids are summed together.
Can throw IllegalArgumentException if the Example contains no features.
- Type Parameters:
T
- The type of the example.
- Parameters:
example
- The example to inspect.
featureMap
- The feature map of the model/dataset (used to preserve hash information).
dataList
- The output feature values.
indicesList
- The output indices.
headersList
- The output header positions (an integer saying how long each sparse example is).
header
- The current header position.
- Returns:
- The updated header position.
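The three lists together form a compressed sparse row (CSR) layout. The sketch below shows that bookkeeping with plain JDK types, assuming the headers are cumulative row offsets that start at 0; features arrive as hypothetical parallel id/value arrays rather than an Example and an ImmutableFeatureMap, and CsrSketch.appendExample is a hypothetical name. Colliding ids are summed by merging into a TreeMap, which also sorts them; this is a swapped-in technique, not necessarily how Tribuo detects collisions.

```java
import java.util.ArrayList;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical analogue of convertSingleExample using plain JDK types. */
final class CsrSketch {
    private CsrSketch() {}

    /** Appends one example's features to the CSR buffers and returns the new header position. */
    static long appendExample(int[] featureIds, float[] values,
                              ArrayList<Float> dataList, ArrayList<Integer> indicesList,
                              ArrayList<Long> headersList, long header) {
        if (featureIds.length == 0) {
            throw new IllegalArgumentException("Example contains no features");
        }
        if (featureIds.length != values.length) {
            throw new IllegalArgumentException("ids and values must have the same length");
        }
        // Sort by feature id and sum the values of colliding ids.
        TreeMap<Integer, Float> merged = new TreeMap<>();
        for (int i = 0; i < featureIds.length; i++) {
            merged.merge(featureIds[i], values[i], Float::sum);
        }
        for (Map.Entry<Integer, Float> e : merged.entrySet()) {
            indicesList.add(e.getKey());
            dataList.add(e.getValue());
        }
        // The header advances by the number of stored (non-colliding) entries.
        header += merged.size();
        headersList.add(header);
        return header;
    }

    public static void main(String[] args) {
        ArrayList<Float> data = new ArrayList<>();
        ArrayList<Integer> indices = new ArrayList<>();
        ArrayList<Long> headers = new ArrayList<>();
        headers.add(0L); // assumed: CSR row offsets start at zero
        long header = 0;
        header = appendExample(new int[]{3, 0, 3}, new float[]{1f, 2f, 5f}, data, indices, headers, header);
        header = appendExample(new int[]{1}, new float[]{4f}, data, indices, headers, header);
        System.out.println(indices + " " + data + " " + headers);
    }
}
```

For the two examples in main, the buffers end up as indices [0, 3, 1], data [2.0, 6.0, 4.0] and headers [0, 2, 3]: the two entries for id 3 in the first example collapse into one entry with value 6.0.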
-
convertSparseVector
protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVector(SparseVector vector) throws ml.dmlc.xgboost4j.java.XGBoostError
Used when predicting with an externally trained XGBoost model.
- Parameters:
vector
- The features to convert.
- Returns:
- A DMatrix representing the features.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library returns an error state.
-
convertSparseVectors
protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVectors(List<SparseVector> vectors) throws ml.dmlc.xgboost4j.java.XGBoostError
Used when predicting with an externally trained XGBoost model.
It is assumed all vectors are the same size when passed into this function.
- Parameters:
vectors
- The batch of features to convert.
- Returns:
- A DMatrix representing the batch of features.
- Throws:
ml.dmlc.xgboost4j.java.XGBoostError
- If the native library returns an error state.
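Batch conversion for prediction can be sketched with the documented same-size assumption made explicit as a check. SparseVec, Csr and toCsr are hypothetical stand-ins (neither Tribuo's SparseVector nor any DMatrix constructor is used), and values are narrowed to float on the assumption that the native side stores 32-bit features.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch: flattening a batch of equally sized sparse vectors into CSR arrays. */
final class SparseBatchSketch {
    private SparseBatchSketch() {}

    /** Hypothetical sparse vector: a dimension plus parallel, ascending index/value arrays. */
    static final class SparseVec {
        final int size;
        final int[] indices;
        final double[] values;

        SparseVec(int size, int[] indices, double[] values) {
            if (indices.length != values.length) {
                throw new IllegalArgumentException("indices and values must have the same length");
            }
            this.size = size;
            this.indices = indices;
            this.values = values;
        }
    }

    /** CSR arrays for a batch; rowHeaders has one more entry than there are rows. */
    static final class Csr {
        final long[] rowHeaders;
        final int[] colIndices;
        final float[] data;
        final int numCols;

        Csr(long[] rowHeaders, int[] colIndices, float[] data, int numCols) {
            this.rowHeaders = rowHeaders;
            this.colIndices = colIndices;
            this.data = data;
            this.numCols = numCols;
        }
    }

    static Csr toCsr(List<SparseVec> vectors) {
        if (vectors.isEmpty()) {
            throw new IllegalArgumentException("No vectors supplied");
        }
        int numCols = vectors.get(0).size;
        int nnz = 0;
        for (SparseVec v : vectors) {
            // Make the documented assumption explicit: every vector has the same dimension.
            if (v.size != numCols) {
                throw new IllegalArgumentException("Expected size " + numCols + ", found " + v.size);
            }
            nnz += v.indices.length;
        }
        long[] rowHeaders = new long[vectors.size() + 1]; // rowHeaders[0] stays 0
        int[] colIndices = new int[nnz];
        float[] data = new float[nnz];
        int pos = 0;
        for (int r = 0; r < vectors.size(); r++) {
            SparseVec v = vectors.get(r);
            for (int k = 0; k < v.indices.length; k++) {
                colIndices[pos] = v.indices[k];
                data[pos] = (float) v.values[k]; // narrowed to 32-bit floats
                pos++;
            }
            rowHeaders[r + 1] = pos;
        }
        return new Csr(rowHeaders, colIndices, data, numCols);
    }

    public static void main(String[] args) {
        List<SparseVec> batch = List.of(
                new SparseVec(4, new int[]{0, 2}, new double[]{1.0, 3.0}),
                new SparseVec(4, new int[]{3}, new double[]{0.5}));
        Csr csr = toCsr(batch);
        System.out.println(Arrays.toString(csr.rowHeaders) + " " + Arrays.toString(csr.colIndices));
    }
}
```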
-