public final class XGBoostModel<T extends Output<T>> extends Model<T>
A Model which wraps around an XGBoost.Booster.
XGBoost is a fast implementation of gradient boosted decision trees.
Throws IllegalStateException if the XGBoost C++ library fails to load or throws an exception.
See:
Chen T, Guestrin C. "XGBoost: A Scalable Tree Boosting System" Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
and for the original algorithm:
Friedman JH. "Greedy Function Approximation: a Gradient Boosting Machine" Annals of Statistics, 2001.
Note: XGBoost requires a native library. On macOS this library requires libomp (which can be installed via Homebrew). On Windows the native library must be compiled into a jar, as it is not contained in the official XGBoost binary on Maven Central.
Modifier and Type | Field and Description |
---|---|
protected List<ml.dmlc.xgboost4j.java.Booster> | models - The XGBoost4J Boosters. |
Fields inherited from class Model: ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Modifier and Type | Method and Description |
---|---|
protected Model<T> | copy(String newName, ModelProvenance newProvenance) - Copies a model, replacing its provenance and name with the supplied values. |
Optional<Excuse<T>> | getExcuse(Example<T> example) - Generates an excuse for an example. |
List<ml.dmlc.xgboost4j.java.Booster> | getInnerModels() - Returns an unmodifiable list containing a copy of each model. |
List<String[]> | getModelDump() - Returns the string model dumps from each Booster. |
Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> | getTopFeatures(int n) - Gets the top n features associated with this model. |
List<Prediction<T>> | predict(Dataset<T> examples) - Uses the model to predict the labels for multiple examples contained in a data set. |
Prediction<T> | predict(Example<T> example) - Uses the model to predict the output for a single example. |
List<Prediction<T>> | predict(Iterable<Example<T>> examples) - Uses the model to predict the label for multiple examples. |
void | setNumThreads(int threads) - Sets the number of threads to use at prediction time. |
Methods inherited from class Model: copy, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, innerPredict, setName, toString, validate
protected transient List<ml.dmlc.xgboost4j.java.Booster> models
public List<ml.dmlc.xgboost4j.java.Booster> getInnerModels()
Returns an unmodifiable list containing a copy of each model. As XGBoost4J models don't expose a copy constructor, this requires serializing each model to a byte array and rebuilding it, and is thus quite expensive.
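The copy-by-serialization approach described for getInnerModels can be sketched with plain JDK classes. This is an illustration only: FakeBooster is a hypothetical stand-in for a Booster, and JDK object serialization stands in for whatever byte format XGBoost4J uses, so none of the names below are part of the Tribuo or XGBoost4J APIs.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

final class SerializationCopySketch {
    /** Hypothetical stand-in for a model type that has no copy constructor. */
    static final class FakeBooster implements Serializable {
        private static final long serialVersionUID = 1L;
        final double[] weights;

        FakeBooster(double[] weights) {
            this.weights = weights;
        }
    }

    /** Copies one model by writing it to a byte array and reading it back. */
    static FakeBooster copyViaBytes(FakeBooster original) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(original);
            }
            try (ObjectInputStream in =
                     new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (FakeBooster) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Failed to copy model", e);
        }
    }

    /** Returns an unmodifiable list holding an independent copy of each model. */
    static List<FakeBooster> copyAll(List<FakeBooster> models) {
        List<FakeBooster> copies = new ArrayList<>();
        for (FakeBooster model : models) {
            copies.add(copyViaBytes(model));
        }
        return Collections.unmodifiableList(copies);
    }
}
```

Every call pays for a full serialize and deserialize round trip per model, which is why callers may prefer to fetch the copies once and reuse them.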
public void setNumThreads(int threads)
Sets the number of threads to use at prediction time. If set to 0, the number of threads is set to the number of hardware threads.
threads - The new number of threads.
public List<Prediction<T>> predict(Dataset<T> examples)
public List<Prediction<T>> predict(Iterable<Example<T>> examples)
public Prediction<T> predict(Example<T> example)
Description copied from class: Model
predict does not mutate the example.
Throws IllegalArgumentException if the example has no features or no feature overlap with the model.
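The two failure conditions above can be illustrated with a small check over feature names. This is a minimal sketch under the assumption that features are identified by string names; checkOverlap and its parameters are hypothetical and do not reflect how Tribuo performs the check internally.

```java
import java.util.List;
import java.util.Set;

final class FeatureOverlapSketch {
    /**
     * Returns how many of the example's feature names are known to the model.
     * Rejects examples that have no features, or none that the model knows.
     */
    static int checkOverlap(Set<String> modelFeatures, List<String> exampleFeatures) {
        if (exampleFeatures.isEmpty()) {
            throw new IllegalArgumentException("Example has no features");
        }
        int overlap = 0;
        for (String name : exampleFeatures) {
            if (modelFeatures.contains(name)) {
                overlap++;
            }
        }
        if (overlap == 0) {
            throw new IllegalArgumentException("No feature overlap with the model");
        }
        return overlap;
    }
}
```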
public Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
Description copied from class: Model
Gets the top n features associated with this model.
If the model does not produce per output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.
If the model cannot describe its top features then it returns Collections.emptyMap().
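The return contract of this method can be sketched with JDK types only. The sketch assumes a single map of feature scores with no per-output breakdown, ranks by descending score (an assumption, the real ordering may differ), uses Map.Entry in place of the OLCUT Pair type, and treats a negative n as "return everything" in line with the parameter description for this method. The ALL_OUTPUTS value here is a placeholder for Model.ALL_OUTPUTS.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

final class TopFeaturesSketch {
    /** Placeholder key standing in for Model.ALL_OUTPUTS. */
    static final String ALL_OUTPUTS = "ALL_OUTPUTS";

    /**
     * Returns the n highest scoring features under the single ALL_OUTPUTS key,
     * all features when n is negative, or an empty map when there are no scores.
     */
    static Map<String, List<Map.Entry<String, Double>>> topFeatures(Map<String, Double> scores, int n) {
        if (scores.isEmpty()) {
            return Collections.emptyMap();
        }
        List<Map.Entry<String, Double>> ranked = new ArrayList<>();
        for (Map.Entry<String, Double> e : scores.entrySet()) {
            ranked.add(new SimpleImmutableEntry<>(e.getKey(), e.getValue()));
        }
        // Highest score first.
        ranked.sort(Map.Entry.<String, Double>comparingByValue().reversed());
        int limit = (n < 0) ? ranked.size() : Math.min(n, ranked.size());
        return Collections.singletonMap(ALL_OUTPUTS, new ArrayList<>(ranked.subList(0, limit)));
    }
}
```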
getTopFeatures in class Model<T extends Output<T>>
n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
public List<String[]> getModelDump()
public Optional<Excuse<T>> getExcuse(Example<T> example)
Description copied from class: Model
This attempts to explain a classification result. Generating an excuse may be quite an expensive operation.
This excuse either contains per class information or an entry with key Model.ALL_OUTPUTS.
The optional is empty if the model does not provide excuses.
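Because the result is an Optional, callers need a path for models that provide no excuse. The following sketch shows one way to handle both cases; Explainer is a hypothetical stand-in interface and a String stands in for the Excuse type, so it is not the Tribuo API.

```java
import java.util.Optional;

final class OptionalExcuseSketch {
    /** Hypothetical stand-in for a model; only some implementations can explain a result. */
    interface Explainer {
        Optional<String> getExcuse(String example);
    }

    /** Formats the excuse if one exists, otherwise reports that none is available. */
    static String describe(Explainer model, String example) {
        return model.getExcuse(example)
                .map(excuse -> "Excuse: " + excuse)
                .orElse("No excuse available for this model");
    }
}
```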
protected Model<T> copy(String newName, ModelProvenance newProvenance)
Description copied from class: Model
Used to provide the provenance removal functionality.
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.