Class XGBoostModel<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.common.xgboost.XGBoostModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable

public final class XGBoostModel<T extends Output<T>> extends Model<T>
A Model which wraps around an XGBoost.Booster.

XGBoost is a fast implementation of gradient boosted decision trees.
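To make the underlying idea concrete, here is a minimal sketch of gradient boosting in the sense of Friedman (2001): squared-loss regression on one feature, where each round fits a depth-1 tree (a stump) to the current residuals. It is not XGBoost's algorithm, which also uses second-order gradient statistics, regularized leaf weights and optimized split finding. `GradientBoostingSketch` and its methods are hypothetical names used only for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of Friedman-style gradient boosting with regression stumps.
// NOT XGBoost: XGBoost also uses second-order gradients, regularized leaf
// weights and optimized split finding.
class GradientBoostingSketch {

    /** Depth-1 regression tree: predicts `left` if x <= threshold, else `right`. */
    static final class Stump {
        final double threshold, left, right;

        Stump(double threshold, double left, double right) {
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        double predict(double x) {
            return x <= threshold ? left : right;
        }
    }

    private final List<Stump> trees = new ArrayList<>();
    private final double learningRate;
    private double base; // initial constant model F0 = mean(y)

    GradientBoostingSketch(double learningRate) {
        this.learningRate = learningRate;
    }

    void fit(double[] x, double[] y, int rounds) {
        int n = x.length;
        base = Arrays.stream(y).average().orElse(0.0);
        double[] pred = new double[n];
        Arrays.fill(pred, base);
        for (int r = 0; r < rounds; r++) {
            // For squared loss, the negative gradient is the residual y - F(x).
            double[] residual = new double[n];
            for (int i = 0; i < n; i++) residual[i] = y[i] - pred[i];
            Stump s = fitStump(x, residual);
            trees.add(s);
            // Shrink each tree's contribution by the learning rate.
            for (int i = 0; i < n; i++) pred[i] += learningRate * s.predict(x[i]);
        }
    }

    double predict(double x) {
        double out = base;
        for (Stump s : trees) out += learningRate * s.predict(x);
        return out;
    }

    /** Tries each observed x as a threshold and keeps the split with the lowest squared error. */
    private static Stump fitStump(double[] x, double[] r) {
        Stump best = null;
        double bestErr = Double.POSITIVE_INFINITY;
        for (double t : x) {
            double sumL = 0, sumR = 0;
            int nL = 0, nR = 0;
            for (int i = 0; i < x.length; i++) {
                if (x[i] <= t) { sumL += r[i]; nL++; } else { sumR += r[i]; nR++; }
            }
            double meanL = nL > 0 ? sumL / nL : 0.0;
            double meanR = nR > 0 ? sumR / nR : 0.0;
            double err = 0;
            for (int i = 0; i < x.length; i++) {
                double d = r[i] - (x[i] <= t ? meanL : meanR);
                err += d * d;
            }
            if (err < bestErr) {
                bestErr = err;
                best = new Stump(t, meanL, meanR);
            }
        }
        return best;
    }
}
```

With a learning rate of 0.5 on `x = {1, 2, 3, 4}`, `y = {1, 1, 5, 5}`, each round halves the residuals, so 20 rounds bring the training predictions within about 1e-5 of the targets.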

Throws IllegalStateException if the XGBoost C++ library fails to load or throws an exception.

See:

 Chen T, Guestrin C.
 "XGBoost: A Scalable Tree Boosting System"
 Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.

and for the original algorithm:

 Friedman JH.
 "Greedy Function Approximation: A Gradient Boosting Machine"
 Annals of Statistics, 2001.

Note: XGBoost requires a native library. On macOS this library requires libomp, which can be installed via Homebrew. On Windows the native library must be compiled into a jar, as it is not contained in the official XGBoost binary on Maven Central.

See Also:
  • Field Details

    • models

      protected transient List<ml.dmlc.xgboost4j.java.Booster> models
      The XGBoost4J Boosters.
  • Method Details

    • getInnerModels

      public List<ml.dmlc.xgboost4j.java.Booster> getInnerModels()
      Returns an unmodifiable list containing a copy of each model.

      As XGBoost4J models don't expose a copy constructor, this requires serializing each model to a byte array and rebuilding it, which makes it quite expensive.

      Returns:
      A copy of all of the models.
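      The copy strategy described above (serialize to bytes, then rebuild) can be sketched with plain Java serialization. `FakeBooster` is a hypothetical stand-in for `ml.dmlc.xgboost4j.java.Booster`, which is round-tripped through XGBoost's own model format rather than `ObjectOutputStream`; the point of the sketch is that every copy costs a full encode and decode of the model.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical sketch: deep-copy models by serializing to bytes and rebuilding.
// FakeBooster stands in for the real XGBoost4J Booster, which uses XGBoost's
// own model format instead of Java object serialization.
class SerializationCopySketch {

    static final class FakeBooster implements Serializable {
        private static final long serialVersionUID = 1L;
        final double[] leafWeights;

        FakeBooster(double[] leafWeights) {
            this.leafWeights = leafWeights;
        }
    }

    /** Returns an unmodifiable list holding an independent copy of each model. */
    static List<FakeBooster> copyAll(List<FakeBooster> models) {
        List<FakeBooster> copies = new ArrayList<>(models.size());
        for (FakeBooster model : models) {
            // Full encode + decode per model: the cost grows with model size.
            copies.add(roundTrip(model));
        }
        return Collections.unmodifiableList(copies);
    }

    private static FakeBooster roundTrip(FakeBooster model) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(model);
            }
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (FakeBooster) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Failed to copy model", e);
        }
    }
}
```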
    • setNumThreads

      public void setNumThreads(int threads)
      Sets the number of threads to use at prediction time.

      If set to 0, nthreads is set to the number of hardware threads.

      Parameters:
      threads - The new number of threads.
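      A minimal sketch of how a requested thread count of 0 might be resolved, assuming the semantics stated above. `resolveThreads` is a hypothetical helper, not part of Tribuo or XGBoost4J, and rejecting negative values is an assumption of this sketch rather than documented behavior.

```java
// Hypothetical helper illustrating the "0 means all hardware threads" convention.
class ThreadCountSketch {

    static int resolveThreads(int requested) {
        // Assumption of this sketch: negative counts are rejected.
        if (requested < 0) {
            throw new IllegalArgumentException("threads must be >= 0, found " + requested);
        }
        // 0 is a sentinel for "one thread per available hardware thread".
        return requested == 0 ? Runtime.getRuntime().availableProcessors() : requested;
    }
}
```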
    • predict

      public List<Prediction<T>> predict(Dataset<T> examples)
      Uses the model to predict the labels for multiple examples contained in a data set.
      Overrides:
      predict in class Model<T extends Output<T>>
      Parameters:
      examples - the data set containing the examples to predict.
      Returns:
      the results of the predictions, in the same order as the data set generates the examples.
    • predict

      public List<Prediction<T>> predict(Iterable<Example<T>> examples)
      Uses the model to predict the label for multiple examples.
      Overrides:
      predict in class Model<T extends Output<T>>
      Parameters:
      examples - the examples to predict.
      Returns:
      the results of the prediction, in the same order as the examples.
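      Both batch predict methods return results in input order. One way to keep that guarantee while scoring in batches (as a Booster scores a whole DMatrix at once) is sketched below. `OrderedBatchPredict`, `batchSize` and `batchPredictor` are hypothetical, and the sketch assumes the batch function returns exactly one output per input, in the same order.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch: batch prediction that preserves the input order.
class OrderedBatchPredict {

    static <E, P> List<P> predictAll(Iterable<E> examples, int batchSize,
                                     Function<List<E>, List<P>> batchPredictor) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<P> results = new ArrayList<>();
        List<E> batch = new ArrayList<>(batchSize);
        for (E example : examples) {
            batch.add(example);
            if (batch.size() == batchSize) {
                flush(batch, batchPredictor, results);
            }
        }
        if (!batch.isEmpty()) {
            flush(batch, batchPredictor, results); // trailing partial batch
        }
        return results;
    }

    private static <E, P> void flush(List<E> batch, Function<List<E>, List<P>> batchPredictor,
                                     List<P> results) {
        List<P> out = batchPredictor.apply(new ArrayList<>(batch));
        if (out.size() != batch.size()) {
            throw new IllegalStateException("predictor returned " + out.size()
                    + " outputs for " + batch.size() + " inputs");
        }
        results.addAll(out); // appending batch by batch keeps the original iteration order
        batch.clear();
    }
}
```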
    • predict

      public Prediction<T> predict(Example<T> example)
      Description copied from class: Model
      Uses the model to predict the output for a single example.

      predict does not mutate the example.

      Throws IllegalArgumentException if the example has no features or no feature overlap with the model.

      Specified by:
      predict in class Model<T extends Output<T>>
      Parameters:
      example - the example to predict.
      Returns:
      the result of the prediction.
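      The two failure cases above (no features, or no overlap with the model's features) can be sketched as a validation step when mapping an example onto the model's feature indices. Representing the example as a `Map<String, Double>` and the model's feature domain as a name-to-index map is a hypothetical simplification of Tribuo's `Example` and feature map; silently dropping features unknown to the model is an assumption of this sketch.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical sketch of the checks predict(Example) documents.
class FeatureOverlapCheck {

    static Map<Integer, Double> toModelIndices(Map<String, Double> example,
                                               Map<String, Integer> featureIndex) {
        if (example.isEmpty()) {
            throw new IllegalArgumentException("Example has no features");
        }
        Map<Integer, Double> indexed = new LinkedHashMap<>();
        for (Map.Entry<String, Double> feature : example.entrySet()) {
            Integer idx = featureIndex.get(feature.getKey());
            if (idx != null) { // features unseen during training are ignored
                indexed.put(idx, feature.getValue());
            }
        }
        if (indexed.isEmpty()) {
            throw new IllegalArgumentException("No feature overlap between example and model");
        }
        return indexed; // the input map is only read, never modified
    }
}
```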
    • getTopFeatures

      public Map<String, List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
      Description copied from class: Model
      Gets the top n features associated with this model.

      If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.

      If the model cannot describe its top features, then it returns Collections.emptyMap().

      Specified by:
      getTopFeatures in class Model<T extends Output<T>>
      Parameters:
      n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
      Returns:
      a map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the model
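      A sketch of the top-n selection for the single-output case (key Model.ALL_OUTPUTS), assuming feature scores are already available as a name-to-score map; how the XGBoost model actually scores its features is not shown. `TopFeatures.topN` is a hypothetical helper, and breaking ties by feature name is a choice made here for deterministic output.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical helper: rank features by a precomputed score and keep the top n.
class TopFeatures {

    static List<Map.Entry<String, Double>> topN(Map<String, Double> scores, int n) {
        List<Map.Entry<String, Double>> sorted = new ArrayList<>();
        for (Map.Entry<String, Double> e : scores.entrySet()) {
            sorted.add(Map.entry(e.getKey(), e.getValue())); // immutable snapshot of each pair
        }
        // Highest score first; ties broken by feature name for a deterministic order.
        sorted.sort((a, b) -> {
            int byScore = Double.compare(b.getValue(), a.getValue());
            return byScore != 0 ? byScore : a.getKey().compareTo(b.getKey());
        });
        // A negative n means "return every feature".
        int limit = n < 0 ? sorted.size() : Math.min(n, sorted.size());
        return new ArrayList<>(sorted.subList(0, limit));
    }
}
```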
    • getModelDump

      public List<String[]> getModelDump()
      Returns the string model dumps from each Booster.
      Returns:
      The model dumps.
    • getExcuse

      public Optional<Excuse<T>> getExcuse(Example<T> example)
      Description copied from class: Model
      Generates an excuse for an example.

      This attempts to explain a classification result. Generating an excuse may be quite an expensive operation.

      This excuse either contains per-class information or an entry with key Model.ALL_OUTPUTS.

      The optional is empty if the model does not provide excuses.

      Specified by:
      getExcuse in class Model<T extends Output<T>>
      Parameters:
      example - The input example.
      Returns:
      An optional excuse object. The optional is empty if this model does not provide excuses.
    • copy

      protected Model<T> copy(String newName, ModelProvenance newProvenance)
      Description copied from class: Model
      Copies a model, replacing its provenance and name with the supplied values.

      Used to provide the provenance removal functionality.

      Specified by:
      copy in class Model<T extends Output<T>>
      Parameters:
      newName - The new name.
      newProvenance - The new provenance.
      Returns:
      A copy of the model.
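      The contract of copy (same learned state, new name and provenance) can be sketched with a small immutable class. `SketchModel`, with a `String` provenance and a `double[]` of weights, is a hypothetical simplification: Tribuo's ModelProvenance is a structured object, and the real model holds XGBoost Boosters.

```java
import java.util.Arrays;

// Hypothetical sketch: a copy that swaps metadata but duplicates learned state.
final class SketchModel {
    final String name;
    final String provenance;
    private final double[] weights;

    SketchModel(String name, String provenance, double[] weights) {
        this.name = name;
        this.provenance = provenance;
        this.weights = weights.clone(); // defensive copy on the way in
    }

    /** Returns a new model with the supplied name and provenance and a private copy of the weights. */
    SketchModel copy(String newName, String newProvenance) {
        return new SketchModel(newName, newProvenance, weights);
    }

    double[] weights() {
        return weights.clone(); // callers cannot mutate internal state
    }

    @Override
    public String toString() {
        return name + " " + provenance + " " + Arrays.toString(weights);
    }
}
```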