Class XGBoostTrainer<T extends Output<T>>

java.lang.Object
org.tribuo.common.xgboost.XGBoostTrainer<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>, WeightedExamples
Direct Known Subclasses:
XGBoostClassificationTrainer, XGBoostRegressionTrainer

public abstract class XGBoostTrainer<T extends Output<T>> extends Object implements Trainer<T>, WeightedExamples
A Trainer which wraps the XGBoost training procedure.

This only exposes a few of XGBoost's training parameters.

It uses pthreads outside of the JVM to parallelise the computation.

See:

 Chen T, Guestrin C.
 "XGBoost: A Scalable Tree Boosting System"
 Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
 
and for the original algorithm:
 Friedman JH.
 "Greedy Function Approximation: a Gradient Boosting Machine"
 Annals of Statistics, 2001.
 
N.B.: This uses a native C implementation of XGBoost that links to various C libraries, including libgomp and glibc. If you're running on Alpine, which does not natively use glibc, you'll need to install glibc into the container. On Windows, this binary is not available in the Maven Central release, so you'll need to compile it from source.
  • Constructor Details

    • XGBoostTrainer

      protected XGBoostTrainer(int numTrees)
    • XGBoostTrainer

      protected XGBoostTrainer(int numTrees, int numThreads, boolean silent)
    • XGBoostTrainer

      protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed)
      Create an XGBoost trainer.
      Parameters:
      numTrees - Number of trees to boost.
      eta - Step size shrinkage parameter (default 0.3, range [0,1]).
      gamma - Minimum loss reduction to make a split (default 0, range [0,inf]).
      maxDepth - Maximum tree depth (default 6, range [1,inf]).
      minChildWeight - Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
      subsample - Fraction of the training examples sampled for each tree (default 1, range (0,1]).
      featureSubsample - Fraction of the features sampled for each tree (default 1, range (0,1]).
      lambda - L2 regularization term on weights (default 1).
      alpha - L1 regularization term on weights (default 0).
      nThread - Number of threads to use (default 4).
      silent - Silence the training output text.
      seed - RNG seed.
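The documented defaults and ranges can be checked before anything reaches the native library. The sketch below is a hypothetical helper (XGBoostParamsSketch is not part of Tribuo) that validates the ranges listed above and assembles a map keyed by XGBoost's native parameter names. The translation of featureSubsample to colsample_bytree and nThread to nthread is an assumption about how such a wrapper could map its arguments, not a description of Tribuo's internals; silent is omitted for simplicity.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper (not part of Tribuo): checks the documented ranges and
// builds a map keyed by XGBoost's native parameter names.
final class XGBoostParamsSketch {
    private XGBoostParamsSketch() {}

    static Map<String, Object> build(double eta, double gamma, int maxDepth, double minChildWeight,
                                     double subsample, double featureSubsample,
                                     double lambda, double alpha, int nThread, long seed) {
        // Ranges as listed in the constructor documentation.
        requireClosed("eta", eta, 0.0, 1.0);
        requireNonNegative("gamma", gamma);
        if (maxDepth < 1) {
            throw new IllegalArgumentException("maxDepth must be >= 1, found " + maxDepth);
        }
        requireNonNegative("minChildWeight", minChildWeight);
        requireUnitHalfOpen("subsample", subsample);
        requireUnitHalfOpen("featureSubsample", featureSubsample);
        // Assumption: at least one thread is required (no range is documented).
        if (nThread < 1) {
            throw new IllegalArgumentException("nThread must be >= 1, found " + nThread);
        }
        // lambda and alpha have no documented range, so they pass through unchecked.
        Map<String, Object> params = new HashMap<>();
        params.put("eta", eta);
        params.put("gamma", gamma);
        params.put("max_depth", maxDepth);
        params.put("min_child_weight", minChildWeight);
        params.put("subsample", subsample);
        params.put("colsample_bytree", featureSubsample); // assumed mapping
        params.put("lambda", lambda);
        params.put("alpha", alpha);
        params.put("nthread", nThread);
        params.put("seed", seed);
        // numTrees is not stored here; in xgboost4j it is normally the number of boosting rounds.
        return params;
    }

    private static void requireClosed(String name, double v, double lo, double hi) {
        if (!(v >= lo && v <= hi)) { // negated form also rejects NaN
            throw new IllegalArgumentException(name + " must be in [" + lo + "," + hi + "], found " + v);
        }
    }

    private static void requireNonNegative(String name, double v) {
        if (!(v >= 0.0)) {
            throw new IllegalArgumentException(name + " must be >= 0, found " + v);
        }
    }

    private static void requireUnitHalfOpen(String name, double v) {
        if (!(v > 0.0 && v <= 1.0)) {
            throw new IllegalArgumentException(name + " must be in (0,1], found " + v);
        }
    }
}
```

Calling XGBoostParamsSketch.build(0.3, 0.0, 6, 1.0, 1.0, 1.0, 1.0, 0.0, 4, 42L) reproduces the documented defaults, while a subsample of 0 is rejected because the documented range excludes it.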
    • XGBoostTrainer

      protected XGBoostTrainer(int numTrees, Map<String,Object> parameters)
      This gives direct access to the XGBoost parameter map.

      It lets you set options that are not otherwise exposed, such as dropout trees or binary classification.

      This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results.

      Parameters:
      numTrees - Number of trees to boost.
      parameters - A map from parameter name to value, where each value is either a Number or a String.
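For illustration, a raw parameter map of the kind this constructor accepts might look like the sketch below. The keys booster, rate_drop and objective are standard XGBoost parameter names selecting DART (dropout) trees and a binary logistic objective, but whether a given combination behaves sensibly inside a particular Tribuo subclass is not stated here. RawXGBoostParams and its checkValueTypes helper are hypothetical; the helper only enforces the documented contract that every value is a Number or a String, since this constructor bypasses Tribuo's own validation.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical example of a raw XGBoost parameter map (not Tribuo code).
final class RawXGBoostParams {
    private RawXGBoostParams() {}

    // Selects dropout trees (DART) and a binary classification objective.
    static Map<String, Object> dartBinaryParams() {
        Map<String, Object> params = new HashMap<>();
        params.put("booster", "dart");              // dropout trees
        params.put("rate_drop", 0.1);               // fraction of trees dropped per round
        params.put("objective", "binary:logistic"); // binary classification
        params.put("max_depth", 4);
        return params;
    }

    // Rejects any value that is not a Number or a String, mirroring the documented contract.
    static void checkValueTypes(Map<String, Object> params) {
        for (Map.Entry<String, Object> e : params.entrySet()) {
            Object v = e.getValue();
            if (!(v instanceof Number) && !(v instanceof String)) {
                throw new IllegalArgumentException("Parameter '" + e.getKey() + "' has unsupported type "
                        + (v == null ? "null" : v.getClass().getName()));
            }
        }
    }
}
```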
    • XGBoostTrainer

      protected XGBoostTrainer()
      For OLCUT.
  • Method Details

    • postConfig

      public void postConfig()
      Used by the OLCUT configuration system, and should not be called by external code.
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • createModel

      protected XGBoostModel<T> createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<ml.dmlc.xgboost4j.java.Booster> models, XGBoostOutputConverter<T> converter)
    • getInvocationCount

      public int getInvocationCount()
      Description copied from interface: Trainer
      The number of times this trainer instance has had its train method invoked.

      This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.

      Specified by:
      getInvocationCount in interface Trainer<T extends Output<T>>
      Returns:
      The number of train invocations.
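The replicability argument can be illustrated with a toy trainer. In this hypothetical sketch (CountingTrainer and its setInvocationCount method are illustrative, not Tribuo code), each train call draws a per-run seed from a trainer-level RNG and increments a counter. A fresh trainer built with the same base seed can then be fast-forwarded by that count so that its next call reproduces the original trainer's next seed. How Tribuo's XGBoost trainer actually derives its per-run seed is not specified here.

```java
import java.util.SplittableRandom;

// Hypothetical trainer showing why the invocation count matters for replicability.
final class CountingTrainer {
    private final long baseSeed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    CountingTrainer(long baseSeed) {
        this.baseSeed = baseSeed;
        this.rng = new SplittableRandom(baseSeed);
    }

    // Stands in for train(): consumes one RNG value and returns it as the per-run seed.
    synchronized long train() {
        long runSeed = rng.nextLong();
        invocationCount++;
        return runSeed;
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    // Resets the RNG and replays 'count' draws, so the next train() matches a
    // trainer with the same base seed that has already been invoked 'count' times.
    synchronized void setInvocationCount(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count must be non-negative, found " + count);
        }
        rng = new SplittableRandom(baseSeed);
        for (int i = 0; i < count; i++) {
            rng.nextLong();
        }
        invocationCount = count;
    }
}
```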
    • convertDataset

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples, Function<T,Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError
    • convertDataset

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples) throws ml.dmlc.xgboost4j.java.XGBoostError
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError
    • convertExamples

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap) throws ml.dmlc.xgboost4j.java.XGBoostError
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError
    • convertExamples

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
      Converts an iterable of examples into a DMatrix.
      Type Parameters:
      T - The type of the output.
      Parameters:
      examples - The examples to convert.
      featureMap - The feature id map which supplies the indices.
      responseExtractor - The extraction function for the output.
      Returns:
      A DMatrixTuple.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library failed to construct the DMatrix.
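The batch conversion can be sketched without the native library by building the three arrays a compressed sparse row (CSR) matrix is made from, plus a label array produced by the response extractor. This is a minimal sketch under stated assumptions: examples are modelled by a hypothetical SimpleExample class holding a feature-name-to-value map and an output of type T, the feature map is a plain Map from name to id, and names missing from it are skipped. Unlike convertSingleExample, it does not sum colliding feature ids. The real method works on Tribuo's Example and ImmutableFeatureMap types and hands the result to xgboost4j.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-in for DMatrixTuple: CSR arrays plus one label per example.
final class CsrBatch {
    final float[] data;    // feature values, row by row
    final int[] indices;   // feature id of each value
    final long[] headers;  // row offsets: row r occupies [headers[r], headers[r + 1])
    final float[] labels;  // response for each example

    private CsrBatch(float[] data, int[] indices, long[] headers, float[] labels) {
        this.data = data;
        this.indices = indices;
        this.headers = headers;
        this.labels = labels;
    }

    // Hypothetical example type: feature name -> value, plus an output of type T.
    static final class SimpleExample<T> {
        final Map<String, Double> features;
        final T output;

        SimpleExample(Map<String, Double> features, T output) {
            this.features = features;
            this.output = output;
        }
    }

    static <T> CsrBatch convert(List<SimpleExample<T>> examples, Map<String, Integer> featureIds,
                                Function<T, Float> responseExtractor) {
        List<Float> data = new ArrayList<>();
        List<Integer> indices = new ArrayList<>();
        long[] headers = new long[examples.size() + 1]; // headers[0] stays 0
        float[] labels = new float[examples.size()];
        for (int r = 0; r < examples.size(); r++) {
            SimpleExample<T> ex = examples.get(r);
            for (Map.Entry<String, Double> f : ex.features.entrySet()) {
                Integer id = featureIds.get(f.getKey());
                if (id != null) { // names missing from the feature map are skipped
                    data.add(f.getValue().floatValue());
                    indices.add(id);
                }
            }
            headers[r + 1] = data.size();
            labels[r] = responseExtractor.apply(ex.output);
        }
        float[] dataArr = new float[data.size()];
        int[] indexArr = new int[indices.size()];
        for (int i = 0; i < dataArr.length; i++) {
            dataArr[i] = data.get(i);
            indexArr[i] = indices.get(i);
        }
        return new CsrBatch(dataArr, indexArr, headers, labels);
    }
}
```

For a classification-style output, the extractor might map each label to its numeric id; for a regression-style output it might return the target value directly.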
    • convertExample

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap) throws ml.dmlc.xgboost4j.java.XGBoostError
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError
    • convertExample

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
      Converts an example into a DMatrix.
      Type Parameters:
      T - The type of the output.
      Parameters:
      example - The example to convert.
      featureMap - The feature id map which supplies the indices.
      responseExtractor - The extraction function for the output.
      Returns:
      A DMatrixTuple.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library failed to construct the DMatrix.
    • convertSingleExample

      protected static <T extends Output<T>> long convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header)
      Writes out the features from an example into the three supplied ArrayLists.

      This is used to transform examples into the format an XGBoost call expects, and is shared by the classification and regression XGBoost backends. The ArrayLists must be non-null and may already contain values, as this method is called repeatedly to accumulate the feature values for an entire dataset.

      Features with colliding feature ids are summed together.

      Can throw IllegalArgumentException if the Example contains no features.

      Type Parameters:
      T - The type of the example.
      Parameters:
      example - The example to inspect.
      featureMap - The feature map of the model/dataset (used to preserve hash information).
      dataList - The output feature values.
      indicesList - The output indices.
      headersList - The output header position (an integer saying how long each sparse example is).
      header - The current header position.
      Returns:
      The updated header position.
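The header bookkeeping and the summing of colliding feature ids can be sketched with plain collections. This is a hypothetical re-implementation, not Tribuo's code: the example is modelled as a list of (name, value) entries, the feature map as a Map from name to id, and unknown names are skipped. It keeps the documented contract of appending to the three ArrayLists and returning the updated header, treating the header as the running count of values written so far (the end offset of this example's row). Throwing IllegalArgumentException when no known features remain is an assumption modelled on the note above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical sketch of convertSingleExample's bookkeeping (not Tribuo code).
final class SingleExampleSketch {
    private SingleExampleSketch() {}

    static long convertSingleExample(List<Map.Entry<String, Double>> example, Map<String, Integer> featureMap,
                                     ArrayList<Float> dataList, ArrayList<Integer> indicesList,
                                     ArrayList<Long> headersList, long header) {
        // Accumulate by feature id: the TreeMap keeps ids sorted and sums colliding values.
        TreeMap<Integer, Float> row = new TreeMap<>();
        for (Map.Entry<String, Double> f : example) {
            Integer id = featureMap.get(f.getKey());
            if (id != null) {
                row.merge(id, f.getValue().floatValue(), Float::sum);
            }
        }
        if (row.isEmpty()) {
            throw new IllegalArgumentException("Example contains no features known to the feature map");
        }
        // Append this example's values after any values already in the lists.
        for (Map.Entry<Integer, Float> e : row.entrySet()) {
            indicesList.add(e.getKey());
            dataList.add(e.getValue());
        }
        header += row.size();
        headersList.add(header);
        return header;
    }
}
```

Calling this once per example, feeding each returned header into the next call, accumulates the values for a whole dataset in the shared lists.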
    • convertSparseVector

      protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVector(SparseVector vector) throws ml.dmlc.xgboost4j.java.XGBoostError
      Used when predicting with an externally trained XGBoost model.
      Parameters:
      vector - The features to convert.
      Returns:
      A DMatrix representing the features.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library returns an error state.
    • convertSparseVectors

      protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVectors(List<SparseVector> vectors) throws ml.dmlc.xgboost4j.java.XGBoostError
      Used when predicting with an externally trained XGBoost model.

      It is assumed all vectors are the same size when passed into this function.

      Parameters:
      vectors - The batch of features to convert.
      Returns:
      A DMatrix representing the batch of features.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library returns an error state.
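Stacking a batch of sparse vectors into CSR form, with the same-size assumption checked rather than merely assumed, might look like the sketch below. SparseVec is a hypothetical minimal stand-in for Tribuo's SparseVector (active indices, their values, and a declared dimension), and the single-vector case is just a batch of one. The real methods return a native DMatrix, whereas this sketch stops at the arrays and the shared column count.

```java
import java.util.List;

// Hypothetical sketch of stacking sparse vectors into CSR arrays (not Tribuo code).
final class SparseStackSketch {
    private SparseStackSketch() {}

    // Minimal sparse vector: parallel arrays of active indices and values, plus a dimension.
    static final class SparseVec {
        final int size;
        final int[] indices;
        final double[] values;

        SparseVec(int size, int[] indices, double[] values) {
            if (indices.length != values.length) {
                throw new IllegalArgumentException("indices and values differ in length");
            }
            this.size = size;
            this.indices = indices;
            this.values = values;
        }
    }

    // CSR arrays for a batch; numCols is the size shared by every vector.
    static final class Csr {
        final long[] headers;
        final int[] indices;
        final float[] data;
        final int numCols;

        Csr(long[] headers, int[] indices, float[] data, int numCols) {
            this.headers = headers;
            this.indices = indices;
            this.data = data;
            this.numCols = numCols;
        }
    }

    static Csr stack(List<SparseVec> vectors) {
        if (vectors.isEmpty()) {
            throw new IllegalArgumentException("empty batch");
        }
        int numCols = vectors.get(0).size;
        int nnz = 0;
        for (SparseVec v : vectors) { // check the same-size assumption up front
            if (v.size != numCols) {
                throw new IllegalArgumentException("expected size " + numCols + ", found " + v.size);
            }
            nnz += v.indices.length;
        }
        long[] headers = new long[vectors.size() + 1]; // headers[0] stays 0
        int[] indices = new int[nnz];
        float[] data = new float[nnz];
        int pos = 0;
        for (int r = 0; r < vectors.size(); r++) {
            SparseVec v = vectors.get(r);
            for (int i = 0; i < v.indices.length; i++) {
                indices[pos] = v.indices[i];
                data[pos] = (float) v.values[i];
                pos++;
            }
            headers[r + 1] = pos;
        }
        return new Csr(headers, indices, data, numCols);
    }
}
```

Keeping numCols alongside the arrays lets a consumer recover the full dimension even when the highest-indexed features are zero in every vector of the batch.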