Class XGBoostTrainer<T extends Output<T>>

java.lang.Object
org.tribuo.common.xgboost.XGBoostTrainer<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>, WeightedExamples
Direct Known Subclasses:
XGBoostClassificationTrainer, XGBoostRegressionTrainer

public abstract class XGBoostTrainer<T extends Output<T>> extends Object implements Trainer<T>, WeightedExamples
A Trainer which wraps the XGBoost training procedure.

This only exposes a few of XGBoost's training parameters.

It uses pthreads outside of the JVM to parallelise the computation.

See:

 Chen T, Guestrin C.
 "XGBoost: A Scalable Tree Boosting System"
 Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
 
and for the original algorithm:
 Friedman JH.
 "Greedy Function Approximation: a Gradient Boosting Machine"
 Annals of Statistics, 2001.
 
N.B.: XGBoost4J wraps the native implementation of XGBoost, which links to various C libraries, including libgomp and glibc (on Linux). If you're running on Alpine, which does not natively use glibc, you'll need to install glibc into the container. The macOS binary on Maven Central is compiled without OpenMP support, so XGBoost is single-threaded on macOS. If necessary, you can recompile the macOS binary with OpenMP support after installing libomp from Homebrew.
  • Field Details

    • parameters

      protected final Map<String,Object> parameters
      The XGBoost parameter map, only accessed internally.
    • overrideParameters

      @Config(description="Override for parameters, if used must contain all the relevant parameters, including the objective") protected Map<String,String> overrideParameters
      Override for the parameter map. If used, it must contain all the parameters, including the objective function.
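Because an override map replaces the generated parameters wholesale, it has to be complete. The sketch below builds such a map and performs a minimal completeness check; the helper class `OverrideParams` and its `requireObjective` method are hypothetical illustrations, not Tribuo API. The keys `objective`, `eta`, `max_depth` and `nthread` are standard XGBoost parameter names and `reg:squarederror` is an XGBoost regression objective, but the values are arbitrary.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper (not part of Tribuo) mirroring the documented contract:
// an override map must be complete and must include the objective.
final class OverrideParams {
    private OverrideParams() {}

    /** Builds a complete override map for a regression task (illustrative values). */
    static Map<String, String> regressionOverride() {
        Map<String, String> params = new HashMap<>();
        params.put("objective", "reg:squarederror"); // mandatory in an override
        params.put("eta", "0.3");
        params.put("max_depth", "6");
        params.put("nthread", "4");
        return params;
    }

    /** Rejects an override that omits the objective function. */
    static Map<String, String> requireObjective(Map<String, String> params) {
        String objective = params.get("objective");
        if (objective == null || objective.isEmpty()) {
            throw new IllegalArgumentException("Override parameters must include an 'objective'.");
        }
        return params;
    }
}
```

The values are strings because `overrideParameters` is declared as `Map<String,String>`; no numeric parsing or range checking happens in this sketch.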
    • numTrees

      @Config(mandatory=true, description="The number of trees to build.") protected int numTrees
      The number of trees to build.
    • trainInvocationCounter

      protected int trainInvocationCounter
      Number of times the train method has been called on this object.
  • Constructor Details

    • XGBoostTrainer

      protected XGBoostTrainer(int numTrees)
      Constructs an XGBoost trainer using the specified number of trees.
      Parameters:
      numTrees - The number of trees.
    • XGBoostTrainer

      protected XGBoostTrainer(int numTrees, int numThreads, boolean silent)
      Constructs an XGBoost trainer using the specified number of trees, number of threads and logging setting.
      Parameters:
      numTrees - The number of trees.
      numThreads - The number of training threads.
      silent - Should the logging be silenced?
    • XGBoostTrainer

      protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed)
      Create an XGBoost trainer.

      Sets the boosting algorithm to XGBoostTrainer.BoosterType.GBTREE and the tree building algorithm to XGBoostTrainer.TreeMethod.AUTO.

      Parameters:
      numTrees - Number of trees to boost.
      eta - Step size shrinkage parameter (default 0.3, range [0,1]).
      gamma - Minimum loss reduction to make a split (default 0, range [0,inf]).
      maxDepth - Maximum tree depth (default 6, range [1,inf]).
      minChildWeight - Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
      subsample - Fraction of the training examples sampled for each tree (default 1, range (0,1]).
      featureSubsample - Fraction of the features sampled for each tree (default 1, range (0,1]).
      lambda - L2 regularization term on weights (default 1).
      alpha - L1 regularization term on weights (default 0).
      nThread - Number of threads to use (default 4).
      silent - Silence the training output text.
      seed - RNG seed.
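The documented ranges can be checked before any native call is made. The following sketch validates the hyperparameters against the ranges listed above and stores them under XGBoost parameter names (`eta`, `gamma`, `max_depth`, `min_child_weight`, `subsample`, `colsample_bytree`, `lambda`, `alpha`, `nthread`, `seed`). The class `XGBoostParamSketch` is hypothetical, the mapping of `featureSubsample` to `colsample_bytree` is an assumption, and Tribuo's own constructor may build its map differently.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch (not Tribuo's code): validate hyperparameters against the
// ranges documented above, then store them under XGBoost parameter names.
final class XGBoostParamSketch {
    private XGBoostParamSketch() {}

    static Map<String, Object> build(double eta, double gamma, int maxDepth,
                                     double minChildWeight, double subsample,
                                     double featureSubsample, double lambda,
                                     double alpha, int nThread, long seed) {
        check(eta >= 0 && eta <= 1, "eta must be in [0,1]");
        check(gamma >= 0, "gamma must be in [0,inf]");
        check(maxDepth >= 1, "maxDepth must be in [1,inf]");
        check(minChildWeight >= 0, "minChildWeight must be in [0,inf]");
        check(subsample > 0 && subsample <= 1, "subsample must be in (0,1]");
        check(featureSubsample > 0 && featureSubsample <= 1, "featureSubsample must be in (0,1]");
        // Assumption: no range is documented for nThread, so only positivity is required.
        check(nThread >= 1, "nThread must be positive");

        Map<String, Object> params = new HashMap<>();
        params.put("eta", eta);
        params.put("gamma", gamma);
        params.put("max_depth", maxDepth);
        params.put("min_child_weight", minChildWeight);
        params.put("subsample", subsample);
        params.put("colsample_bytree", featureSubsample); // assumed mapping for featureSubsample
        params.put("lambda", lambda);
        params.put("alpha", alpha);
        params.put("nthread", nThread);
        params.put("seed", seed);
        return params;
    }

    private static void check(boolean condition, String message) {
        if (!condition) {
            throw new IllegalArgumentException(message);
        }
    }
}
```

Failing fast in Java gives a clear `IllegalArgumentException` instead of an opaque error from the native library at training time.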
    • XGBoostTrainer

      protected XGBoostTrainer(XGBoostTrainer.BoosterType boosterType, XGBoostTrainer.TreeMethod treeMethod, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, XGBoostTrainer.LoggingVerbosity verbosity, long seed)
      Create an XGBoost trainer.
      Parameters:
      boosterType - The base learning algorithm.
      treeMethod - The tree building algorithm if using a tree booster.
      numTrees - Number of trees to boost.
      eta - Step size shrinkage parameter (default 0.3, range [0,1]).
      gamma - Minimum loss reduction to make a split (default 0, range [0,inf]).
      maxDepth - Maximum tree depth (default 6, range [1,inf]).
      minChildWeight - Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
      subsample - Fraction of the training examples sampled for each tree (default 1, range (0,1]).
      featureSubsample - Fraction of the features sampled for each tree (default 1, range (0,1]).
      lambda - L2 regularization term on weights (default 1).
      alpha - L1 regularization term on weights (default 0).
      nThread - Number of threads to use (default 4).
      verbosity - Set the logging verbosity of the native library.
      seed - RNG seed.
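The enum arguments ultimately have to become string or integer entries in the native parameter map. The sketch below uses hypothetical stand-in enums (`Booster`, `Method`, `Verbosity`) whose constants are assumptions loosely mirroring `XGBoostTrainer.BoosterType`, `XGBoostTrainer.TreeMethod` and `XGBoostTrainer.LoggingVerbosity`; only GBTREE and AUTO are named on this page. The strings `gbtree`, `gblinear`, `dart`, `auto`, `exact`, `approx`, `hist` and the integer `verbosity` levels 0 to 3 follow XGBoost's documented parameters.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

// Hypothetical stand-ins; the real Tribuo enums may define other constants.
enum Booster { GBTREE, GBLINEAR, DART }
enum Method { AUTO, EXACT, APPROX, HIST }
enum Verbosity { SILENT, WARNING, INFO, DEBUG }

final class EnumParamSketch {
    private EnumParamSketch() {}

    static Map<String, Object> toParams(Booster booster, Method method, Verbosity verbosity) {
        Map<String, Object> params = new HashMap<>();
        // XGBoost expects lower-case strings such as "gbtree" or "hist".
        params.put("booster", booster.name().toLowerCase(Locale.ROOT));
        // The tree method only applies to tree boosters, so it is omitted for gblinear.
        if (booster != Booster.GBLINEAR) {
            params.put("tree_method", method.name().toLowerCase(Locale.ROOT));
        }
        // Verbosity is an integer level: 0 silent, 1 warning, 2 info, 3 debug.
        params.put("verbosity", verbosity.ordinal());
        return params;
    }
}
```

Relying on `ordinal()` ties the integer level to declaration order, which is convenient here but fragile if constants are reordered; an explicit field per constant would be more robust.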
    • XGBoostTrainer

      protected XGBoostTrainer(int numTrees, Map<String,Object> parameters)
      This gives direct access to the XGBoost parameter map.

      It lets you set options that are not otherwise exposed, such as dropout trees (DART) or binary classification.

      This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results.

      Parameters:
      numTrees - Number of trees to boost.
      parameters - A map from string to object, where object can be Number or String.
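A raw parameter map is where options such as dropout trees go. This sketch assembles such a map and applies the one constraint stated above, that every value is a `Number` or a `String`; the class `RawParamSketch` is hypothetical. `booster=dart`, `rate_drop`, `skip_drop` and `objective=binary:logistic` are XGBoost parameter names, but the values chosen are arbitrary, and since this path bypasses Tribuo's validation nothing checks that they are sensible.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper: assemble a raw XGBoost parameter map for DART boosting
// and enforce the documented value-type constraint (Number or String only).
final class RawParamSketch {
    private RawParamSketch() {}

    static Map<String, Object> dartBinaryParams() {
        Map<String, Object> params = new HashMap<>();
        params.put("booster", "dart");              // dropout trees
        params.put("rate_drop", 0.1);               // fraction of previous trees dropped
        params.put("skip_drop", 0.5);               // probability of skipping dropout in a round
        params.put("objective", "binary:logistic"); // binary classification
        params.put("max_depth", 4);
        return params;
    }

    static Map<String, Object> checkValueTypes(Map<String, Object> params) {
        for (Map.Entry<String, Object> e : params.entrySet()) {
            Object v = e.getValue();
            if (!(v instanceof Number) && !(v instanceof String)) {
                String found = (v == null) ? "null" : v.getClass().getName();
                throw new IllegalArgumentException(
                        "Parameter '" + e.getKey() + "' must be a Number or String, found " + found);
            }
        }
        return params;
    }
}
```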
    • XGBoostTrainer

      protected XGBoostTrainer()
      For OLCUT.
  • Method Details

    • postConfig

      public void postConfig()
      Used by the OLCUT configuration system, and should not be called by external code.
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • createModel

      protected XGBoostModel<T> createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<ml.dmlc.xgboost4j.java.Booster> models, XGBoostOutputConverter<T> converter)
      Creates an XGBoost model from the booster list.
      Parameters:
      name - The model name.
      provenance - The model provenance.
      featureIDMap - The feature domain.
      outputIDInfo - The output domain.
      models - The boosters.
      converter - The converter from XGBoost's output to Tribuo predictions.
      Returns:
      An XGBoost model.
    • copyParams

      protected Map<String,Object> copyParams(Map<String,?> input)
      Returns a copy of the supplied parameter map which has the appropriate type for passing to XGBoost.train.
      Parameters:
      input - The parameter map.
      Returns:
      A (shallow) copy of the supplied map.
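Accepting `Map<String,?>` and returning `Map<String,Object>` lets a caller pass, for example, a `Map<String,String>` where `XGBoost.train` needs `Map<String,Object>`. A shallow copy like the hypothetical one below duplicates the map structure but shares the value objects with the input; because the values here are immutable `Number` and `String` instances, that sharing is harmless.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical re-creation of a shallow parameter copy: the map is new,
// but the value objects are shared with the input.
final class CopyParamsSketch {
    private CopyParamsSketch() {}

    static Map<String, Object> copyParams(Map<String, ?> input) {
        // HashMap's copy constructor accepts Map<? extends K, ? extends V>,
        // so a Map<String,String> widens to Map<String,Object> without casts.
        return new HashMap<>(input);
    }
}
```

Adding keys to the copy, as a trainer might when filling in per-call settings, leaves the caller's map untouched.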
    • getInvocationCount

      public int getInvocationCount()
      Description copied from interface: Trainer
      The number of times this trainer instance has had its train method invoked.

      This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.

      Specified by:
      getInvocationCount in interface Trainer<T extends Output<T>>
      Returns:
      The number of train invocations.
    • setInvocationCount

      public void setInvocationCount(int invocationCount)
      Description copied from interface: Trainer
      Set the internal state of the trainer to the provided number of invocations of the train method.

      This is used when reproducing a Tribuo-trained model: simulating invocations of the train method returns the RNG to the state it was in when Tribuo trained the original model. This method should ALWAYS be overridden; the default method is purely for compatibility.

      In a future major release this default implementation will be removed.

      Specified by:
      setInvocationCount in interface Trainer<T extends Output<T>>
      Parameters:
      invocationCount - the number of invocations of the train method to simulate
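Tracking invocations matters because a trainer's random stream advances with each `train` call. The toy trainer below is hypothetical: it assumes each call draws one seed from a `java.util.SplittableRandom`, whereas Tribuo's trainers may consume their RNG differently. It shows how replaying the count restores the stream so that the next call sees the same seed it saw originally.

```java
import java.util.SplittableRandom;

// Hypothetical toy trainer: each train() call consumes one value from the RNG,
// illustrating the idea behind getInvocationCount/setInvocationCount.
final class ToyTrainer {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    ToyTrainer(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stands in for training; returns the per-call seed that would be passed to XGBoost. */
    synchronized long train() {
        long callSeed = rng.nextLong();
        invocationCount++;
        return callSeed;
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Resets the RNG, then simulates the given number of train calls. */
    synchronized void setInvocationCount(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count must be non-negative");
        }
        rng = new SplittableRandom(seed);
        invocationCount = 0;
        while (invocationCount < count) {
            rng.nextLong(); // discard the value a real call would have used
            invocationCount++;
        }
    }
}
```

A fresh trainer with the same seed, after `setInvocationCount(2)`, produces the same third seed as a trainer that was actually trained three times.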
    • convertDataset

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples, Function<T,Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
      Converts a dataset into a DMatrix.
      Type Parameters:
      T - The type of the output.
      Parameters:
      examples - The examples to convert.
      responseExtractor - The extraction function for the output.
      Returns:
      A DMatrixTuple.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library failed to construct the DMatrix.
    • convertDataset

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertDataset(Dataset<T> examples) throws ml.dmlc.xgboost4j.java.XGBoostError
      Converts a dataset into a DMatrix.
      Type Parameters:
      T - The type of the output.
      Parameters:
      examples - The examples to convert.
      Returns:
      A DMatrixTuple.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library failed to construct the DMatrix.
    • convertExamples

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap) throws ml.dmlc.xgboost4j.java.XGBoostError
      Converts an iterable of examples into a DMatrix.
      Type Parameters:
      T - The type of the output.
      Parameters:
      examples - The examples to convert.
      featureMap - The feature id map which supplies the indices.
      Returns:
      A DMatrixTuple.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library failed to construct the DMatrix.
    • convertExamples

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
      Converts an iterable of examples into a DMatrix.
      Type Parameters:
      T - The type of the output.
      Parameters:
      examples - The examples to convert.
      featureMap - The feature id map which supplies the indices.
      responseExtractor - The extraction function for the output.
      Returns:
      A DMatrixTuple.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library failed to construct the DMatrix.
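The `responseExtractor` argument turns each output into the single float label XGBoost trains on. The sketch below uses hypothetical `ClassLabel` and `RegressionValue` classes in place of Tribuo's output types and shows two plausible extractors: an index lookup for classification and a direct value read for regression. Neither is claimed to be what `XGBoostClassificationTrainer` or `XGBoostRegressionTrainer` actually pass.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical output types standing in for Tribuo's classification and regression outputs.
final class ClassLabel {
    final String name;
    ClassLabel(String name) { this.name = name; }
}

final class RegressionValue {
    final double value;
    RegressionValue(double value) { this.value = value; }
}

final class ResponseSketch {
    private ResponseSketch() {}

    /** Applies the extractor to every output, producing a float label array. */
    static <T> float[] extractLabels(List<T> outputs, Function<T, Float> responseExtractor) {
        float[] labels = new float[outputs.size()];
        for (int i = 0; i < labels.length; i++) {
            labels[i] = responseExtractor.apply(outputs.get(i));
        }
        return labels;
    }

    /** Classification: map a label name to its integer id, expressed as a float. */
    static Function<ClassLabel, Float> byIndex(Map<String, Integer> ids) {
        return label -> (float) ids.get(label.name).intValue();
    }

    /** Regression: use the single regression value directly, narrowed to float. */
    static Function<RegressionValue, Float> byValue() {
        return r -> (float) r.value;
    }
}
```

Keeping extraction behind a `Function` lets one conversion routine serve every output type, which matches the split between the overloads with and without `responseExtractor`.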
    • convertExample

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap) throws ml.dmlc.xgboost4j.java.XGBoostError
      Converts an example into a DMatrix.
      Type Parameters:
      T - The type of the output.
      Parameters:
      example - The example to convert.
      featureMap - The feature id map which supplies the indices.
      Returns:
      A DMatrixTuple.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library failed to construct the DMatrix.
    • convertExample

      protected static <T extends Output<T>> XGBoostTrainer.DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) throws ml.dmlc.xgboost4j.java.XGBoostError
      Converts an example into a DMatrix.
      Type Parameters:
      T - The type of the output.
      Parameters:
      example - The example to convert.
      featureMap - The feature id map which supplies the indices.
      responseExtractor - The extraction function for the output.
      Returns:
      A DMatrixTuple.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library failed to construct the DMatrix.
    • convertSingleExample

      protected static <T extends Output<T>> long convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header)
      Writes out the features from an example into the three supplied ArrayLists.

      This is used to transform examples into the right format for an XGBoost call. It's used in both the classification and regression XGBoost backends. The ArrayLists must be non-null and can contain existing values, as this method is called repeatedly to build up ArrayLists containing all the feature values for a dataset.

      Features with colliding feature ids are summed together.

      Can throw IllegalArgumentException if the Example contains no features.

      Type Parameters:
      T - The type of the example.
      Parameters:
      example - The example to inspect.
      featureMap - The feature map of the model/dataset (used to preserve hash information).
      dataList - The output feature values.
      indicesList - The output indices.
      headersList - The output header position (an integer saying how long each sparse example is).
      header - The current header position.
      Returns:
      The updated header position.
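The three lists together resemble a compressed sparse row (CSR) layout: `dataList` holds values, `indicesList` holds feature ids, and `headersList` records where each example's row ends. The sketch below is a simplified, hypothetical stand-in for `convertSingleExample` that takes already-mapped feature ids and values as parallel arrays instead of an `Example` and `ImmutableFeatureMap`. Interpreting the header as a cumulative row offset, and seeding `headersList` with 0 as CSR row pointers conventionally start, are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical simplification of convertSingleExample: appends one row to
// flat CSR-style lists and returns the updated cumulative header position.
final class CsrSketch {
    private CsrSketch() {}

    static long appendRow(int[] featureIds, float[] values,
                          ArrayList<Float> dataList, ArrayList<Integer> indicesList,
                          ArrayList<Long> headersList, long header) {
        if (featureIds.length == 0) {
            throw new IllegalArgumentException("Example contains no features.");
        }
        if (featureIds.length != values.length) {
            throw new IllegalArgumentException("ids and values must have the same length.");
        }
        // TreeMap sorts by feature id and sums the values of colliding ids.
        TreeMap<Integer, Float> row = new TreeMap<>();
        for (int i = 0; i < featureIds.length; i++) {
            row.merge(featureIds[i], values[i], Float::sum);
        }
        for (Map.Entry<Integer, Float> e : row.entrySet()) {
            indicesList.add(e.getKey());
            dataList.add(e.getValue());
        }
        // The new header marks where this row ends in the flat lists.
        header += row.size();
        headersList.add(header);
        return header;
    }
}
```

For example, ids `{3, 1, 3}` with values `{1, 2, 4}` become a row with indices `[1, 3]` and values `[2, 5]`, and the header advances by 2.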
    • convertSparseVector

      protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVector(SparseVector vector) throws ml.dmlc.xgboost4j.java.XGBoostError
      Used when predicting with an externally trained XGBoost model.
      Parameters:
      vector - The features to convert.
      Returns:
      A DMatrix representing the features.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library returns an error state.
    • convertSparseVectors

      protected static ml.dmlc.xgboost4j.java.DMatrix convertSparseVectors(List<SparseVector> vectors) throws ml.dmlc.xgboost4j.java.XGBoostError
      Used when predicting with an externally trained XGBoost model.

      It is assumed all vectors are the same size when passed into this function.

      Parameters:
      vectors - The batch of features to convert.
      Returns:
      A DMatrix representing the batch of features.
      Throws:
      ml.dmlc.xgboost4j.java.XGBoostError - If the native library returns an error state.
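Batch conversion relies on every vector sharing the same dimension, since that dimension becomes the column count of the resulting matrix. The hypothetical sketch below represents each vector by its size plus parallel index and value arrays (`SimpleSparse`, a stand-in for Tribuo's `SparseVector`), checks the same-size assumption explicitly instead of trusting it, and flattens the batch into CSR arrays of the kind a sparse-matrix constructor typically consumes. The classes `CsrBatch` and `BatchSketch` are also illustrative.

```java
import java.util.List;

// Hypothetical stand-in for a sparse vector: a fixed size plus
// parallel arrays of active indices and their values.
final class SimpleSparse {
    final int size;
    final int[] indices;
    final double[] values;

    SimpleSparse(int size, int[] indices, double[] values) {
        this.size = size;
        this.indices = indices;
        this.values = values;
    }
}

/** CSR arrays for a batch: row offsets, column indices, values and column count. */
final class CsrBatch {
    final long[] rowHeaders;
    final int[] columns;
    final float[] data;
    final int numColumns;

    CsrBatch(long[] rowHeaders, int[] columns, float[] data, int numColumns) {
        this.rowHeaders = rowHeaders;
        this.columns = columns;
        this.data = data;
        this.numColumns = numColumns;
    }
}

final class BatchSketch {
    private BatchSketch() {}

    static CsrBatch toCsr(List<SimpleSparse> vectors) {
        if (vectors.isEmpty()) {
            throw new IllegalArgumentException("Empty batch.");
        }
        int numColumns = vectors.get(0).size;
        int nnz = 0;
        for (SimpleSparse v : vectors) {
            if (v.size != numColumns) {
                throw new IllegalArgumentException(
                        "All vectors must have the same size, expected " + numColumns + " found " + v.size);
            }
            nnz += v.indices.length;
        }
        long[] headers = new long[vectors.size() + 1]; // headers[0] stays 0
        int[] columns = new int[nnz];
        float[] data = new float[nnz];
        int pos = 0;
        for (int row = 0; row < vectors.size(); row++) {
            SimpleSparse v = vectors.get(row);
            for (int j = 0; j < v.indices.length; j++) {
                columns[pos] = v.indices[j];
                data[pos] = (float) v.values[j]; // XGBoost works with 32-bit float features
                pos++;
            }
            headers[row + 1] = pos;
        }
        return new CsrBatch(headers, columns, data, numColumns);
    }
}
```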