Class BaggingTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.ensemble.BaggingTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>
- Direct Known Subclasses:
RandomForestTrainer
A Trainer that wraps another trainer and produces a bagged ensemble.
A bagged ensemble is a set of models each of which was trained on a bootstrap sample of the original dataset, combined with an unweighted majority vote.
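To make the bootstrap step concrete, here is a minimal sketch in plain Java with no Tribuo types. The class `BootstrapSketch` and its methods are hypothetical illustrations, not part of Tribuo's API: each call draws n row indices uniformly with replacement, so a sample typically repeats some rows and omits others.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

// Hypothetical helper for illustration; not part of Tribuo.
class BootstrapSketch {

    // Draws n row indices uniformly from [0, n) with replacement.
    static int[] bootstrapIndices(int n, SplittableRandom rng) {
        int[] indices = new int[n];
        for (int i = 0; i < n; i++) {
            indices[i] = rng.nextInt(n); // value in [0, n)
        }
        return indices;
    }

    // Number of distinct rows that made it into the sample.
    static long distinctCount(int[] indices) {
        return Arrays.stream(indices).distinct().count();
    }

    public static void main(String[] args) {
        int[] sample = bootstrapIndices(1000, new SplittableRandom(42L));
        System.out.println("distinct rows: " + distinctCount(sample) + " of " + sample.length);
    }
}
```

For large n, roughly 1 - 1/e (about 63%) of the distinct rows are expected to appear in each sample, which is what gives each ensemble member a different view of the data.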
See:
J. Friedman, T. Hastie, & R. Tibshirani. "The Elements of Statistical Learning." Springer, 2001.
Field Summary
Fields
Modifier and Type | Field
protected EnsembleCombiner<T> | combiner
protected Trainer<T> | innerTrainer
protected int | numMembers
protected SplittableRandom | rng
protected long | seed
protected int | trainInvocationCounter
Fields inherited from interface org.tribuo.Trainer: DEFAULT_SEED
Constructor Summary
Constructors
Modifier | Constructor | Description
protected | BaggingTrainer() | For the configuration system.
public | BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers) |
public | BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed) |
Method Summary
Modifier and Type | Method | Description
protected String | ensembleName() |
int | getInvocationCount() | The number of times this trainer instance has had its train method invoked.
TrainerProvenance | getProvenance() |
void | postConfig() |
String | toString() |
Model<T> | train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) | Trains a predictive model using the examples in the given data set.
protected Model<T> | trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, SplittableRandom localRNG, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) |
Field Details
innerTrainer
@Config(mandatory=true, description="The trainer to use for each ensemble member.") protected Trainer<T extends Output<T>> innerTrainer
numMembers
@Config(mandatory=true, description="The number of ensemble members to train.") protected int numMembers
seed
protected long seed
combiner
@Config(mandatory=true, description="The combination function to aggregate each ensemble member's outputs.") protected EnsembleCombiner<T extends Output<T>> combiner
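The `combiner` field makes the aggregation step pluggable. The sketch below is a hypothetical stand-in built on `java.util.function.Function`, not Tribuo's `EnsembleCombiner` interface (which works on prediction objects rather than raw values). It shows an unweighted majority vote over label strings and, for contrast, a plain average of the kind one might use for real-valued outputs.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical combination function: maps the members' outputs to one ensemble output.
// This is a sketch, not Tribuo's EnsembleCombiner API.
interface Combiner<T> extends Function<List<T>, T> {}

class Combiners {

    // Unweighted majority vote: every member's label counts once.
    static Combiner<String> majorityVote() {
        return votes -> {
            Map<String, Integer> counts = new LinkedHashMap<>(); // keeps first-seen order
            for (String label : votes) {
                counts.merge(label, 1, Integer::sum);
            }
            String best = null;
            int bestCount = 0;
            for (Map.Entry<String, Integer> e : counts.entrySet()) {
                if (e.getValue() > bestCount) { // strict '>' so ties keep the earlier label
                    best = e.getKey();
                    bestCount = e.getValue();
                }
            }
            return best;
        };
    }

    // Unweighted mean, e.g. for real-valued outputs.
    static Combiner<Double> average() {
        return outputs -> outputs.stream()
                .mapToDouble(Double::doubleValue)
                .average()
                .orElse(Double.NaN);
    }

    public static void main(String[] args) {
        System.out.println(majorityVote().apply(List.of("cat", "dog", "dog")));
        System.out.println(average().apply(List.of(1.0, 2.0, 6.0)));
    }
}
```

Ties in this sketch go to the label seen first; that tie-breaking rule is a choice made here, not something the documentation specifies.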
rng
protected SplittableRandom rng
trainInvocationCounter
protected int trainInvocationCounter
Constructor Details
BaggingTrainer
protected BaggingTrainer()
For the configuration system.
BaggingTrainer
public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers)
BaggingTrainer
public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed)
Method Details
postConfig
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
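The protected no-argument constructor and `postConfig()` follow a common configuration pattern: the configuration system instantiates the object, injects the `@Config` fields, and then calls `postConfig()` so that derived state can be built from them. The class below is a hypothetical plain-Java sketch of that pattern. It assumes, without the documentation confirming it, that the derived state is a `SplittableRandom` built from `seed`; the `numMembers` check is illustrative only.

```java
import java.util.SplittableRandom;

// Hypothetical sketch of the "no-arg constructor + postConfig()" pattern; not Tribuo source.
class ConfigurableSketch {
    protected int numMembers;       // injected by a configuration system in this pattern
    protected long seed;            // injected by a configuration system in this pattern
    protected SplittableRandom rng; // derived state, built in postConfig()

    // For the configuration system: fields are assigned after construction.
    protected ConfigurableSketch() {}

    // Programmatic construction sets the fields, then runs the same hook.
    ConfigurableSketch(int numMembers, long seed) {
        this.numMembers = numMembers;
        this.seed = seed;
        postConfig();
    }

    // Called once every configurable field is set; validates and derives state.
    public void postConfig() {
        if (numMembers < 1) { // illustrative check only
            throw new IllegalArgumentException("numMembers must be positive, found " + numMembers);
        }
        this.rng = new SplittableRandom(seed);
    }

    public static void main(String[] args) {
        ConfigurableSketch viaConfig = new ConfigurableSketch();
        viaConfig.numMembers = 10; // what a configuration system would inject
        viaConfig.seed = 1L;
        viaConfig.postConfig();
        System.out.println(viaConfig.rng.nextLong() == new ConfigurableSketch(10, 1L).rng.nextLong());
    }
}
```

Routing both construction paths through the same hook keeps the object in the same state whether it was built in code or from a configuration file.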
ensembleName
protected String ensembleName()
toString
public String toString()
train
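public Model<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Trains a predictive model using the examples in the given data set.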
trainSingleModel
protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, SplittableRandom localRNG, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
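The signature of `trainSingleModel` suggests a loop that trains `numMembers` models, each on its own bootstrap sample drawn with a supplied `SplittableRandom`. The following is a hypothetical plain-Java sketch of such a loop, with the inner trainer modelled as a `Function` from a sample to a model. It splits one local generator per call and reuses it for every member; that is an assumption about the scheme, not a description of Tribuo's implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.SplittableRandom;
import java.util.function.Function;

// Hypothetical bagging loop in plain Java; not Tribuo's implementation.
class BaggingLoopSketch {

    // Trains numMembers models, each on a bootstrap sample of data.
    static <D, M> List<M> trainMembers(List<D> data, int numMembers, SplittableRandom rng,
                                       Function<List<D>, M> innerTrainer) {
        // Assumption: one local generator is split off per call and shared by all members.
        SplittableRandom localRNG = rng.split();
        List<M> members = new ArrayList<>(numMembers);
        for (int m = 0; m < numMembers; m++) {
            List<D> sample = new ArrayList<>(data.size());
            for (int i = 0; i < data.size(); i++) {
                sample.add(data.get(localRNG.nextInt(data.size()))); // draw with replacement
            }
            members.add(innerTrainer.apply(sample));
        }
        return members;
    }

    public static void main(String[] args) {
        List<Double> data = List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0);
        // Toy "inner trainer": the model is just the mean of its bootstrap sample.
        Function<List<Double>, Double> meanTrainer =
                s -> s.stream().mapToDouble(Double::doubleValue).average().orElse(Double.NaN);
        System.out.println(trainMembers(data, 5, new SplittableRandom(3L), meanTrainer));
    }
}
```

Passing the generator in, rather than creating one inside the loop, keeps the whole ensemble reproducible from a single seed.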
getInvocationCount
public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
- The number of train invocations.
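The replicability argument can be sketched as follows. In this hypothetical class (not Tribuo code), every `train()` call splits a shared, seeded `SplittableRandom` and increments a counter. Because `split()` advances the parent generator deterministically, a fresh instance with the same seed can reproduce the stream of a later call by re-splitting once per earlier invocation. The `setInvocationCount` method is an illustrative helper invented for this sketch.

```java
import java.util.SplittableRandom;

// Hypothetical trainer shell (not Tribuo code) showing why an invocation counter enables replay.
class CountingTrainerSketch {
    private final long seed;
    private SplittableRandom rng;
    private int trainInvocationCounter = 0;

    CountingTrainerSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    // Stand-in for train(): takes one split of the shared RNG and returns its first draw.
    synchronized long train() {
        SplittableRandom localRNG = rng.split(); // advances the parent deterministically
        trainInvocationCounter++;
        return localRNG.nextLong();
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    // Illustrative helper: rebuild the RNG and skip past `count` earlier invocations.
    synchronized void setInvocationCount(int count) {
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split();
        }
        trainInvocationCounter = count;
    }

    public static void main(String[] args) {
        CountingTrainerSketch original = new CountingTrainerSketch(1234L);
        original.train();
        original.train();
        long third = original.train();

        CountingTrainerSketch replay = new CountingTrainerSketch(1234L);
        replay.setInvocationCount(original.getInvocationCount() - 1);
        System.out.println(replay.train() == third);
    }
}
```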
getProvenance
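public TrainerProvenance getProvenance()
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>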