Class BaggingTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.ensemble.BaggingTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>
- Direct Known Subclasses:
ExtraTreesTrainer, RandomForestTrainer
A Trainer that wraps another trainer and produces a bagged ensemble.
A bagged ensemble is a set of models, each trained on a bootstrap sample of the original dataset, whose predictions are aggregated by the configured EnsembleCombiner. For classification this is typically an unweighted majority vote.
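For the classification case with an unweighted majority vote, the procedure can be written as below. The notation is introduced here for illustration and is not taken from the Tribuo source: n is the number of training examples, M corresponds to numMembers, and f_m is the model the inner trainer produces from the m-th bootstrap sample. The last line is the standard probability that a given example appears at least once in a bootstrap sample of size n.

```latex
\begin{aligned}
&\text{Bootstrap sample for member } m = 1, \dots, M:\\
&\quad \mathcal{D}_m = \{(x_{j_{m,1}}, y_{j_{m,1}}), \dots, (x_{j_{m,n}}, y_{j_{m,n}})\},
  \qquad j_{m,k} \overset{\text{i.i.d.}}{\sim} \mathrm{Uniform}\{1, \dots, n\}\\
&\text{Unweighted majority vote over the members } f_1, \dots, f_M:\\
&\quad \hat{y}(x) = \operatorname*{arg\,max}_{c} \sum_{m=1}^{M} \mathbb{1}\left[f_m(x) = c\right]\\
&\text{Probability that example } i \text{ appears in } \mathcal{D}_m:\\
&\quad \Pr[i \in \mathcal{D}_m] = 1 - \left(1 - \tfrac{1}{n}\right)^{n} \approx 1 - e^{-1} \approx 0.632 \text{ for large } n
\end{aligned}
```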
See:
J. Friedman, T. Hastie, and R. Tibshirani. "The Elements of Statistical Learning." Springer, 2001.
Field Summary
Fields
- protected EnsembleCombiner<T> combiner
- protected Trainer<T> innerTrainer
- protected int numMembers
- protected SplittableRandom rng
- protected long seed
- protected int trainInvocationCounter
Fields inherited from interface org.tribuo.Trainer
- DEFAULT_SEED
Constructor Summary
Constructors
- protected BaggingTrainer(): For the configuration system.
- BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers)
- BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed)
Method Summary
Methods
- protected String ensembleName()
- int getInvocationCount(): The number of times this trainer instance has had its train method invoked.
- TrainerProvenance getProvenance()
- void postConfig(): Used by the OLCUT configuration system, and should not be called by external code.
- String toString()
- Model<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a predictive model using the examples in the given data set.
- protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, SplittableRandom localRNG, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Field Details
innerTrainer
protected Trainer<T> innerTrainer
numMembers
@Config(mandatory=true, description="The number of ensemble members to train.")
protected int numMembers
seed
@Config(mandatory=true, description="The seed for the RNG.")
protected long seed
combiner
@Config(mandatory=true, description="The combination function to aggregate each ensemble member's outputs.")
protected EnsembleCombiner<T extends Output<T>> combiner
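The following Java sketch shows what an unweighted majority vote over ensemble members' predicted labels amounts to. The class MajorityVoteSketch and its combine method are hypothetical illustrations, not Tribuo's EnsembleCombiner API. Labels are plain strings here, and ties are broken in favour of the label seen first, which is an arbitrary choice made for this sketch.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration of an unweighted majority vote; not Tribuo's EnsembleCombiner API.
class MajorityVoteSketch {

    // Returns the most frequent label. Ties go to the label that appears first in the list,
    // because LinkedHashMap iterates in insertion order and we only replace on a strictly larger count.
    static String combine(List<String> memberPredictions) {
        if (memberPredictions.isEmpty()) {
            throw new IllegalArgumentException("no ensemble member predictions supplied");
        }
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (String label : memberPredictions) {
            counts.merge(label, 1, Integer::sum); // each member casts one equally weighted vote
        }
        String best = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                best = entry.getKey();
                bestCount = entry.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Three hypothetical ensemble members vote on one example.
        System.out.println(combine(List.of("spam", "ham", "spam"))); // prints spam
    }
}
```

A regression ensemble would instead average the members' outputs, which is why the combiner is a configurable field rather than a fixed vote.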
rng
protected SplittableRandom rng
trainInvocationCounter
protected int trainInvocationCounter
Constructor Details
BaggingTrainer
protected BaggingTrainer()
For the configuration system.
BaggingTrainer
public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers)
BaggingTrainer
public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed)
Method Details
postConfig
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
ensembleName
protected String ensembleName()
toString
public String toString()
train
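public Model<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Trains a predictive model using the examples in the given data set.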
trainSingleModel
protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, SplittableRandom localRNG, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
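trainSingleModel receives a localRNG, which suggests each ensemble member's bootstrap sample is drawn from that generator. The sketch below shows one plausible way to draw such a sample with java.util.SplittableRandom: n indices sampled uniformly with replacement. BootstrapSketch and bootstrapIndices are hypothetical names, and this is not claimed to be Tribuo's own sampling code.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

// Hypothetical sketch of bootstrap sampling; not Tribuo's own implementation.
class BootstrapSketch {

    // Draws n indices uniformly at random, with replacement, from [0, n).
    static int[] bootstrapIndices(int n, SplittableRandom localRNG) {
        int[] indices = new int[n];
        for (int i = 0; i < n; i++) {
            indices[i] = localRNG.nextInt(n); // uniform over [0, n)
        }
        return indices;
    }

    public static void main(String[] args) {
        int numExamples = 10;
        int numMembers = 3;
        // Assumption: one local generator per training run, split from a trainer-level generator.
        SplittableRandom trainerRNG = new SplittableRandom(12345L);
        SplittableRandom localRNG = trainerRNG.split();
        for (int m = 0; m < numMembers; m++) {
            // Each member consumes the next values from the same local generator
            // rather than reusing a seed, so the members see different draws.
            int[] bag = bootstrapIndices(numExamples, localRNG);
            long distinct = Arrays.stream(bag).distinct().count();
            System.out.println("member " + m + ": " + Arrays.toString(bag)
                    + " (" + distinct + " distinct of " + numExamples + ")");
        }
    }
}
```

Because sampling is with replacement, each bag typically contains duplicates and omits roughly a third of the original examples.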
getInvocationCount
public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed, so that the random number stream can be replicated.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
The number of train invocations.
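The replicability argument above can be illustrated with a small, self-contained sketch: if each train call consumes exactly one split() of a trainer-level SplittableRandom seeded from seed, then a fresh instance with the same seed, advanced by the recorded invocation count, produces the same random stream on its next call. ReplicableTrainerSketch, train(), and fastForward(int) are hypothetical names chosen for illustration; they are assumptions and do not describe Tribuo's actual implementation.

```java
import java.util.SplittableRandom;

// Hypothetical sketch of seed + invocation-count replicability; names are illustrative only.
class ReplicableTrainerSketch {
    private final long seed;
    private SplittableRandom rng;
    private int trainInvocationCounter;

    ReplicableTrainerSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    // Each call consumes exactly one split of the trainer-level RNG and bumps the counter.
    // The returned long stands in for whatever randomness a real training run would use.
    synchronized long train() {
        SplittableRandom localRNG = rng.split();
        trainInvocationCounter++;
        return localRNG.nextLong();
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    // Rebuilds the RNG from the seed and replays `count` splits, so the next train()
    // call sees the same random stream as call number count + 1 on the original instance.
    synchronized void fastForward(int count) {
        rng = new SplittableRandom(seed);
        trainInvocationCounter = 0;
        for (int i = 0; i < count; i++) {
            rng.split();
            trainInvocationCounter++;
        }
    }

    public static void main(String[] args) {
        ReplicableTrainerSketch original = new ReplicableTrainerSketch(42L);
        original.train();
        original.train();
        long third = original.train();

        ReplicableTrainerSketch replay = new ReplicableTrainerSketch(42L);
        replay.fastForward(original.getInvocationCount() - 1);
        System.out.println(replay.train() == third); // prints true
    }
}
```

The synchronized methods reflect the design concern that the split and the counter increment must happen together; otherwise concurrent train calls could leave the count out of step with the RNG state.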
getProvenance
public TrainerProvenance getProvenance()