Class XGBoostClassificationTrainer
java.lang.Object
org.tribuo.common.xgboost.XGBoostTrainer<Label>
org.tribuo.classification.xgboost.XGBoostClassificationTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<Label>, WeightedExamples
A Trainer which wraps the XGBoost training procedure.

This only exposes a few of XGBoost's training parameters. It uses pthreads outside of the JVM to parallelise the computation.

See:
Chen T, Guestrin C. "XGBoost: A Scalable Tree Boosting System", Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
and for the original algorithm:
Friedman JH. "Greedy Function Approximation: a Gradient Boosting Machine", Annals of Statistics, 2001.
Note: XGBoost requires a native library. On macOS this library requires libomp (which can be installed via Homebrew); on Windows the native library must be compiled into a jar, as it is not contained in the official XGBoost binary on Maven Central.
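As a minimal usage sketch, the trainer can be constructed and applied to a labelled dataset like so. The CSV path and response column name below are hypothetical placeholders; adjust them for your data.

```java
import org.tribuo.MutableDataset;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.xgboost.XGBoostClassificationTrainer;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.evaluation.TrainTestSplitter;

import java.nio.file.Paths;

public class XGBoostExample {
    public static void main(String[] args) throws Exception {
        // Load a labelled CSV; "species" is a hypothetical response column name.
        var loader = new CSVLoader<>(new LabelFactory());
        var source = loader.loadDataSource(Paths.get("iris.csv"), "species");

        // 70/30 train/test split with a fixed RNG seed for reproducibility.
        var splitter = new TrainTestSplitter<>(source, 0.7, 1L);
        var trainData = new MutableDataset<>(splitter.getTrain());
        var testData = new MutableDataset<>(splitter.getTest());

        // Boost 50 trees, using XGBoost's defaults for the other hyperparameters.
        var trainer = new XGBoostClassificationTrainer(50);
        var model = trainer.train(trainData);

        // Generate predictions for the held-out data.
        var predictions = model.predict(testData);
        System.out.println("Predicted " + predictions.size() + " examples");
    }
}
```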
Nested Class Summary

Nested classes/interfaces inherited from class org.tribuo.common.xgboost.XGBoostTrainer:
XGBoostTrainer.BoosterType, XGBoostTrainer.DMatrixTuple<T extends Output<T>>, XGBoostTrainer.XGBoostTrainerProvenance
Field Summary

Fields inherited from class org.tribuo.common.xgboost.XGBoostTrainer:
numTrees, parameters, trainInvocationCounter

Fields inherited from interface org.tribuo.Trainer:
DEFAULT_SEED
Constructor Summary

Constructors:
protected XGBoostClassificationTrainer()
    For olcut.
XGBoostClassificationTrainer(int numTrees)
XGBoostClassificationTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed)
    Create an XGBoost trainer.
XGBoostClassificationTrainer(int numTrees, int numThreads, boolean silent)
XGBoostClassificationTrainer(int numTrees, Map<String,Object> parameters)
    This gives direct access to the XGBoost parameter map.
Method Summary

void postConfig()
    Used by the OLCUT configuration system, and should not be called by external code.
XGBoostModel<Label> train(Dataset<Label> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
    Trains a predictive model using the examples in the given data set.

Methods inherited from class org.tribuo.common.xgboost.XGBoostTrainer:
convertDataset, convertDataset, convertExample, convertExample, convertExamples, convertExamples, convertSingleExample, convertSparseVector, convertSparseVectors, createModel, getInvocationCount, toString
Constructor Details

XGBoostClassificationTrainer

public XGBoostClassificationTrainer(int numTrees)

XGBoostClassificationTrainer

public XGBoostClassificationTrainer(int numTrees, int numThreads, boolean silent)

XGBoostClassificationTrainer

public XGBoostClassificationTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed)

Create an XGBoost trainer.

Parameters:
numTrees - Number of trees to boost.
eta - Step size shrinkage parameter (default 0.3, range [0,1]).
gamma - Minimum loss reduction to make a split (default 0, range [0,inf]).
maxDepth - Maximum tree depth (default 6, range [1,inf]).
minChildWeight - Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
subsample - Subsample size for each tree (default 1, range (0,1]).
featureSubsample - Subsample features for each tree (default 1, range (0,1]).
lambda - L2 regularization term on weights (default 1).
alpha - L1 regularization term on weights (default 0).
nThread - Number of threads to use (default 4).
silent - Silence the training output text.
seed - RNG seed.
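The fully-parameterised constructor might be invoked as in the sketch below; the hyperparameter values chosen here are purely illustrative, not recommendations.

```java
import org.tribuo.classification.xgboost.XGBoostClassificationTrainer;

public class FullConstructorExample {
    public static void main(String[] args) {
        // Illustrative values: 100 trees with a smaller step size and
        // shallower trees than the defaults, 4 threads, and a fixed seed.
        var trainer = new XGBoostClassificationTrainer(
            100,   // numTrees
            0.1,   // eta
            0.0,   // gamma
            4,     // maxDepth
            1.0,   // minChildWeight
            0.8,   // subsample
            0.8,   // featureSubsample
            1.0,   // lambda
            0.0,   // alpha
            4,     // nThread
            true,  // silent
            42L    // seed
        );
        System.out.println(trainer);
    }
}
```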
 
XGBoostClassificationTrainer

public XGBoostClassificationTrainer(int numTrees, Map<String,Object> parameters)

This gives direct access to the XGBoost parameter map. It lets you set options which aren't exposed directly, like dropout trees, binary classification etc. This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results.

Parameters:
numTrees - Number of trees to boost.
parameters - A map from string to object, where the object can be a Number or a String.
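A sketch of the parameter-map constructor is below. The keys used are standard XGBoost learner parameters (see the XGBoost documentation for the full list); the specific values are illustrative only, and, as noted above, they bypass Tribuo's hyperparameter validation.

```java
import org.tribuo.classification.xgboost.XGBoostClassificationTrainer;

import java.util.Map;

public class ParameterMapExample {
    public static void main(String[] args) {
        // Values may be Numbers or Strings. "booster"/"rate_drop" select
        // dropout (DART) trees, an option not exposed by the other constructors.
        Map<String, Object> params = Map.of(
            "eta", 0.1,
            "max_depth", 4,
            "booster", "dart",
            "rate_drop", 0.1
        );
        var trainer = new XGBoostClassificationTrainer(100, params);
        System.out.println(trainer);
    }
}
```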
XGBoostClassificationTrainer

protected XGBoostClassificationTrainer()

For olcut.
 
Method Details

postConfig

public void postConfig()

Used by the OLCUT configuration system, and should not be called by external code.

Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
Overrides:
postConfig in class XGBoostTrainer<Label>
 
train

public XGBoostModel<Label> train(Dataset<Label> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)

Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.

Parameters:
examples - the data set containing the examples.
runProvenance - Training run specific provenance (e.g., fold number).
Returns:
a predictive model that can be used to generate predictions for new examples.
 
getProvenance

public TrainerProvenance getProvenance()