public abstract class XGBoostTrainer<T extends Output<T>> extends Object implements Trainer<T>, WeightedExamples
A Trainer which wraps the XGBoost training procedure.
This only exposes a few of XGBoost's training parameters.
It uses pthreads outside of the JVM to parallelise the computation.
See:
Chen T, Guestrin C. "XGBoost: A Scalable Tree Boosting System" Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
and for the original algorithm:
Friedman JH. "Greedy Function Approximation: a Gradient Boosting Machine" Annals of Statistics, 2001.
N.B.: This uses a native C implementation of xgboost that links to various C libraries, including libgomp and glibc. If you're running on Alpine, which does not natively use glibc, you'll need to install glibc into the container. On Windows this binary is not available in the Maven Central release, so you'll need to compile it from source.
Modifier and Type | Class and Description |
---|---|
static class | XGBoostTrainer.BoosterType - The type of XGBoost model. |
protected static class | XGBoostTrainer.DMatrixTuple<T extends Output<T>> - Tuple of a DMatrix, the number of valid features in each example, and the examples themselves. |
protected static class | XGBoostTrainer.XGBoostTrainerProvenance - Deprecated. |
Modifier and Type | Field and Description |
---|---|
protected int | numTrees |
protected Map<String,Object> | parameters |
protected int | trainInvocationCounter |
Fields inherited from interface Trainer: DEFAULT_SEED
Modifier | Constructor and Description |
---|---|
protected | XGBoostTrainer() - For OLCUT. |
protected | XGBoostTrainer(int numTrees) |
protected | XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed) - Create an XGBoost trainer. |
protected | XGBoostTrainer(int numTrees, int numThreads, boolean silent) |
protected | XGBoostTrainer(int numTrees, Map<String,Object> parameters) - This gives direct access to the XGBoost parameter map. |
Modifier and Type | Method and Description |
---|---|
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> | convertDataset(Dataset<T> examples) |
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> | convertDataset(Dataset<T> examples, Function<T,Float> responseExtractor) |
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> | convertExample(Example<T> example, ImmutableFeatureMap featureMap) |
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> | convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) - Converts an example into a DMatrix. |
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> | convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap) |
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> | convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) - Converts an iterable of examples into a DMatrix. |
protected static <T extends Output<T>> long | convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header) - Writes out the features from an example into the three supplied ArrayLists. |
protected static ml.dmlc.xgboost4j.java.DMatrix | convertSparseVector(SparseVector vector) - Used when predicting with an externally trained XGBoost model. |
protected static ml.dmlc.xgboost4j.java.DMatrix | convertSparseVectors(List<SparseVector> vectors) - Used when predicting with an externally trained XGBoost model. |
protected XGBoostModel<T> | createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<ml.dmlc.xgboost4j.java.Booster> models, XGBoostOutputConverter<T> converter) |
int | getInvocationCount() - The number of times this trainer instance has had its train method invoked. |
void | postConfig() - Used by the OLCUT configuration system, and should not be called by external code. |
String | toString() |
@Config(mandatory=true, description="The number of trees to build.") protected int numTrees
protected int trainInvocationCounter
protected XGBoostTrainer(int numTrees)
protected XGBoostTrainer(int numTrees, int numThreads, boolean silent)
protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed)
Create an XGBoost trainer.
numTrees - Number of trees to boost.
eta - Step size shrinkage parameter (default 0.3, range [0,1]).
gamma - Minimum loss reduction to make a split (default 0, range [0,inf)).
maxDepth - Maximum tree depth (default 6, range [1,inf)).
minChildWeight - Minimum sum of instance weights needed in a leaf (default 1, range [0,inf)).
subsample - Subsample size for each tree (default 1, range (0,1]).
featureSubsample - Subsample features for each tree (default 1, range (0,1]).
lambda - L2 regularization term on weights (default 1).
alpha - L1 regularization term on weights (default 0).
nThread - Number of threads to use (default 4).
silent - Silence the training output text.
seed - RNG seed.
protected XGBoostTrainer(int numTrees, Map<String,Object> parameters)
This gives direct access to the XGBoost parameter map.
It lets you choose XGBoost options that Tribuo does not expose directly, such as dropout trees or a binary classification objective.
This sidesteps the validation that Tribuo provides for the hyperparameters, so it can produce unexpected results.
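As a rough illustration of what a parameter map for this constructor might look like, the sketch below builds a Map<String,Object> and runs a small pre-flight check on it. The keys are XGBoost's own parameter names ("booster", "objective", "eta", "max_depth", "subsample", "colsample_bytree", "nthread"); how they correspond to the arguments of the other constructors is an assumption. The class XGBoostParameterMapSketch and its validate helper are hypothetical and not part of Tribuo; the checks only mirror the value types and ranges documented on this page.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;

public class XGBoostParameterMapSketch {

    /** Builds a parameter map that selects dropout trees and a binary objective. */
    public static Map<String, Object> dartBinaryParameters() {
        Map<String, Object> params = new LinkedHashMap<>();
        params.put("booster", "dart");              // dropout trees
        params.put("objective", "binary:logistic"); // binary classification
        params.put("eta", 0.3);                     // step size shrinkage
        params.put("max_depth", 6);                 // maximum tree depth
        params.put("subsample", 1.0);               // row subsample per tree
        params.put("colsample_bytree", 1.0);        // feature subsample per tree
        params.put("nthread", 4);                   // native threads
        return params;
    }

    /**
     * Hypothetical check, since the map constructor skips validation.
     * Returns a description of the first problem found, or empty if none.
     */
    public static Optional<String> validate(Map<String, Object> params) {
        for (Map.Entry<String, Object> e : params.entrySet()) {
            Object v = e.getValue();
            if (!(v instanceof Number) && !(v instanceof String)) {
                return Optional.of(e.getKey() + ": value must be a Number or a String");
            }
        }
        Object eta = params.get("eta");
        if (eta instanceof Number) {
            double d = ((Number) eta).doubleValue();
            if (d < 0.0 || d > 1.0) {
                return Optional.of("eta: must be in [0,1]");
            }
        }
        Object subsample = params.get("subsample");
        if (subsample instanceof Number) {
            double d = ((Number) subsample).doubleValue();
            if (d <= 0.0 || d > 1.0) {
                return Optional.of("subsample: must be in (0,1]");
            }
        }
        Object depth = params.get("max_depth");
        if (depth instanceof Number && ((Number) depth).intValue() < 1) {
            return Optional.of("max_depth: must be at least 1");
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        Map<String, Object> params = dartBinaryParameters();
        System.out.println(params);
        System.out.println(validate(params).orElse("no problems found"));
    }
}
```

A map like this would be passed to a concrete subclass's equivalent of this constructor; running a check of this kind first is one way to catch out-of-range values before they reach the native library.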
numTrees - Number of trees to boost.
parameters - A map from string to object, where object can be Number or String.
protected XGBoostTrainer()
For OLCUT.
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
Specified by: postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
protected XGBoostModel<T> createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<ml.dmlc.xgboost4j.java.Booster> models, XGBoostOutputConverter<T> converter)
public int getInvocationCount()
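Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.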
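To make the replicability point concrete, here is a minimal sketch of how an invocation count can be used to replay a random number stream. It is not Tribuo's implementation: CountingTrainer, its train method and fastForward are hypothetical names, and the use of java.util.SplittableRandom with one split per training call is an assumption made for illustration.

```java
import java.util.SplittableRandom;

public class InvocationCountSketch {

    /** Hypothetical trainer: each call to train() consumes one split of the seeded RNG. */
    public static final class CountingTrainer {
        private final long seed;
        private SplittableRandom rng;
        private int trainInvocationCounter = 0;

        public CountingTrainer(long seed) {
            this.seed = seed;
            this.rng = new SplittableRandom(seed);
        }

        /** Stands in for training: returns a value drawn from this call's private RNG. */
        public synchronized long train() {
            SplittableRandom localRng = rng.split();
            trainInvocationCounter++;
            return localRng.nextLong();
        }

        public synchronized int getInvocationCount() {
            return trainInvocationCounter;
        }

        /** Resets the RNG and replays splits so the next train() acts like call number count + 1. */
        public synchronized void fastForward(int count) {
            rng = new SplittableRandom(seed);
            trainInvocationCounter = 0;
            for (int i = 0; i < count; i++) {
                rng.split();
                trainInvocationCounter++;
            }
        }
    }

    public static void main(String[] args) {
        CountingTrainer original = new CountingTrainer(12345L);
        original.train();
        original.train();
        int countBeforeThird = original.getInvocationCount();
        long third = original.train();

        // A fresh trainer with the same seed reproduces the third call once fast-forwarded.
        CountingTrainer replay = new CountingTrainer(12345L);
        replay.fastForward(countBeforeThird);
        System.out.println(third == replay.train()); // prints "true"
    }
}
```

The design choice illustrated is that the seed alone is not enough to reproduce a given model when a trainer instance is reused; the seed plus the invocation count identifies the position in the stream.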
Specified by: getInvocationCount in interface Trainer<T extends Output<T>>
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples, Function<T,Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
Throws: ml.dmlc.xgboost4j.java.XGBoostError
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples) throws ml.dmlc.xgboost4j.java.XGBoostError
Throws: ml.dmlc.xgboost4j.java.XGBoostError
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap) throws ml.dmlc.xgboost4j.java.XGBoostError
Throws: ml.dmlc.xgboost4j.java.XGBoostError
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
Converts an iterable of examples into a DMatrix.
T - The type of the output.
examples - The examples to convert.
featureMap - The feature id map which supplies the indices.
responseExtractor - The extraction function for the output.
Throws: ml.dmlc.xgboost4j.java.XGBoostError - If the native library failed to construct the DMatrix.
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap) throws ml.dmlc.xgboost4j.java.XGBoostError
Throws: ml.dmlc.xgboost4j.java.XGBoostError
protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
Converts an example into a DMatrix.
T - The type of the output.
example - The example to convert.
featureMap - The feature id map which supplies the indices.
responseExtractor - The extraction function for the output.
Throws: ml.dmlc.xgboost4j.java.XGBoostError - If the native library failed to construct the DMatrix.
protected static <T extends Output<T>> long convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header)
Writes out the features from an example into the three supplied ArrayLists.
This is used to transform examples into the right format for an XGBoost call. It's used in both the Classification and Regression XGBoost backends. The ArrayLists must be non-null, and can contain existing values (as this method is called multiple times to build up an ArrayList containing all the feature values for a dataset).
Features with colliding feature ids are summed together.
Can throw IllegalArgumentException if the Example contains no features.
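The three-list layout described above resembles a compressed sparse row (CSR) encoding, and a minimal sketch of building one row at a time follows. It does not use Tribuo types: CsrRowSketch and appendRow are hypothetical, a plain Map from feature name to column index stands in for the feature map, and it is an assumption that the returned long is the updated header position and that the caller seeds the header list with 0.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

public class CsrRowSketch {

    /**
     * Appends one sparse row to the three lists and returns the new header position.
     * featureIds maps a feature name to its column index (a stand-in for the feature map).
     */
    public static long appendRow(Map<String, Double> features, Map<String, Integer> featureIds,
                                 List<Float> dataList, List<Integer> indicesList,
                                 List<Long> headersList, long header) {
        // Sum values whose names map to the same column index, keeping indices sorted.
        TreeMap<Integer, Float> row = new TreeMap<>();
        for (Map.Entry<String, Double> f : features.entrySet()) {
            Integer id = featureIds.get(f.getKey());
            if (id != null) { // names missing from the id map are skipped
                row.merge(id, f.getValue().floatValue(), Float::sum);
            }
        }
        if (row.isEmpty()) {
            throw new IllegalArgumentException("Example contains no usable features");
        }
        for (Map.Entry<Integer, Float> e : row.entrySet()) {
            indicesList.add(e.getKey());
            dataList.add(e.getValue());
        }
        long newHeader = header + row.size();
        headersList.add(newHeader);
        return newHeader;
    }

    public static void main(String[] args) {
        Map<String, Integer> ids = new HashMap<>();
        ids.put("a", 0);
        ids.put("b", 2);
        ids.put("b-alias", 2); // collides with "b"

        List<Float> data = new ArrayList<>();
        List<Integer> indices = new ArrayList<>();
        List<Long> headers = new ArrayList<>();
        long header = 0L;
        headers.add(header); // the first row starts at offset 0

        Map<String, Double> first = new HashMap<>();
        first.put("a", 1.0);
        first.put("b", 2.0);
        first.put("b-alias", 0.5);
        header = appendRow(first, ids, data, indices, headers, header);

        Map<String, Double> second = new HashMap<>();
        second.put("b", 4.0);
        header = appendRow(second, ids, data, indices, headers, header);

        System.out.println(data);    // [1.0, 2.5, 4.0]
        System.out.println(indices); // [0, 2, 2]
        System.out.println(headers); // [0, 2, 3]
    }
}
```

Because the lists are appended to rather than replaced, the same three lists can be reused across every example in a dataset, which matches the note above that they may already contain values.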
T - The type of the example.
example - The example to inspect.
featureMap - The feature map of the model/dataset (used to preserve hash information).
dataList - The output feature values.
indicesList - The output indices.
headersList - The output header position (an integer saying how long each sparse example is).
header - The current header position.
protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVector(SparseVector vector) throws ml.dmlc.xgboost4j.java.XGBoostError
Used when predicting with an externally trained XGBoost model.
vector - The features to convert.
Throws: ml.dmlc.xgboost4j.java.XGBoostError - If the native library returns an error state.
protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVectors(List<SparseVector> vectors) throws ml.dmlc.xgboost4j.java.XGBoostError
Used when predicting with an externally trained XGBoost model.
It is assumed all vectors are the same size when passed into this function.
vectors - The batch of features to convert.
Throws: ml.dmlc.xgboost4j.java.XGBoostError - If the native library returns an error state.
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.