Class BaggingTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.ensemble.BaggingTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>
- Direct Known Subclasses:
ExtraTreesTrainer, RandomForestTrainer
A Trainer that wraps another trainer and produces a bagged ensemble.
A bagged ensemble is a set of models, each trained on a bootstrap sample of the original dataset, whose outputs are combined by the supplied EnsembleCombiner, for example with an unweighted majority vote.
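The unweighted majority vote can be sketched in plain Java as follows. This is a minimal illustration over String labels, not Tribuo's combiner implementation; the class and method names are hypothetical, and breaking ties in favour of the label seen first is an assumption of this sketch.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class MajorityVoteSketch {
    /** Returns the label predicted by the most members; ties go to the label seen first (assumption). */
    static String majorityVote(List<String> memberPredictions) {
        if (memberPredictions.isEmpty()) {
            throw new IllegalArgumentException("No member predictions to combine");
        }
        // LinkedHashMap keeps first-seen order so tie-breaking is deterministic.
        Map<String, Integer> votes = new LinkedHashMap<>();
        for (String label : memberPredictions) {
            votes.merge(label, 1, Integer::sum); // every member gets exactly one unweighted vote
        }
        String best = null;
        int bestCount = -1;
        for (Map.Entry<String, Integer> e : votes.entrySet()) {
            // Strict '>' means an earlier label keeps the win on a tie.
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Five members, each trained on its own bootstrap sample, predict the same example.
        System.out.println(majorityVote(List.of("spam", "ham", "spam", "spam", "ham")));
    }
}
```

Here `main` prints `spam`, since three of the five members voted for it.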
See:
J. Friedman, T. Hastie, & R. Tibshirani. "The Elements of Statistical Learning." Springer, 2001.
Field Summary
- protected EnsembleCombiner<T> combiner
- protected Trainer<T> innerTrainer
- protected int numMembers
- protected SplittableRandom rng
- protected long seed
- protected int trainInvocationCounter
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED
Constructor Summary
- protected BaggingTrainer(): For the configuration system.
- BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers)
- BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed)
Method Summary
- protected String ensembleName()
- int getInvocationCount(): The number of times this trainer instance has had its train method invoked.
- TrainerProvenance getProvenance()
- void postConfig(): Used by the OLCUT configuration system, and should not be called by external code.
- String toString()
- Model<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a predictive model using the examples in the given data set.
- protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, SplittableRandom localRNG, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Field Details
innerTrainer
numMembers
@Config(mandatory=true, description="The number of ensemble members to train.") protected int numMembers
seed
@Config(mandatory=true, description="The seed for the RNG.") protected long seed
combiner
@Config(mandatory=true, description="The combination function to aggregate each ensemble member's outputs.") protected EnsembleCombiner<T> combiner
rng
protected SplittableRandom rng
trainInvocationCounter
protected int trainInvocationCounter
Constructor Details
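BaggingTrainer
protected BaggingTrainer()
For the configuration system.
BaggingTrainer
public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers)
BaggingTrainer
public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed)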
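The relationship between the constructors, the seed and postConfig can be sketched as below. This is a hypothetical stand-alone class, assuming (as the inherited DEFAULT_SEED field suggests) that the three-argument constructor falls back to a default seed, and that postConfig derives the SplittableRandom from the configured seed. The inner trainer and combiner are stubbed as plain Object fields, and DEFAULT_SEED_PLACEHOLDER is an illustrative value, not Tribuo's actual Trainer.DEFAULT_SEED.

```java
import java.util.SplittableRandom;

public class BaggingConfigSketch {
    // Hypothetical stand-in for Trainer.DEFAULT_SEED; the real value is defined by org.tribuo.Trainer.
    static final long DEFAULT_SEED_PLACEHOLDER = 42L;

    // Stand-ins for the configurable fields; the real class uses Trainer<T> and EnsembleCombiner<T>.
    protected Object innerTrainer;
    protected Object combiner;
    protected int numMembers;
    protected long seed;
    protected SplittableRandom rng;
    protected int trainInvocationCounter;

    /** No-arg constructor: a configuration system would inject the fields, then call postConfig(). */
    protected BaggingConfigSketch() { }

    BaggingConfigSketch(Object trainer, Object combiner, int numMembers) {
        // Assumed: the short constructor delegates with the default seed.
        this(trainer, combiner, numMembers, DEFAULT_SEED_PLACEHOLDER);
    }

    BaggingConfigSketch(Object trainer, Object combiner, int numMembers, long seed) {
        this.innerTrainer = trainer;
        this.combiner = combiner;
        this.numMembers = numMembers;
        this.seed = seed;
        postConfig(); // reuse the same initialisation path as configuration-driven construction
    }

    /** Derives per-instance state from the configured fields (assumed behaviour). */
    public void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    public static void main(String[] args) {
        BaggingConfigSketch a = new BaggingConfigSketch("inner", "vote", 10);
        BaggingConfigSketch b = new BaggingConfigSketch("inner", "vote", 10, DEFAULT_SEED_PLACEHOLDER);
        // Same seed, so both RNGs produce the same first value.
        System.out.println(a.rng.nextLong() == b.rng.nextLong());
    }
}
```

Routing both construction paths through postConfig keeps the RNG initialisation in one place, whether the fields were set by a constructor or by a configuration file.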
Method Details
postConfig
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
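ensembleName
protected String ensembleName()
toString
public String toString()
train
public Model<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Trains a predictive model using the examples in the given data set.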
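A sketch of what the training loop might look like, assuming it splits a local RNG off the shared one under a lock, increments the invocation counter, and then trains numMembers members, each on a fresh bootstrap sample drawn from that local RNG. To stay self-contained, the "inner trainer" is a hypothetical majority-class learner over String labels and the members are combined by an unweighted vote; none of these names are Tribuo API.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;

public class BaggingTrainSketch {
    private final int numMembers;
    private final SplittableRandom rng;
    private int trainInvocationCounter = 0;

    BaggingTrainSketch(int numMembers, long seed) {
        this.numMembers = numMembers;
        this.rng = new SplittableRandom(seed);
    }

    /** Hypothetical inner trainer: the "model" is the most frequent label in a non-empty sample. */
    static String trainMajorityClass(List<String> sample) {
        Map<String, Integer> counts = new HashMap<>();
        for (String y : sample) counts.merge(y, 1, Integer::sum);
        return counts.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .get()
                .getKey();
    }

    /** Draws data.size() examples uniformly with replacement (a bootstrap sample). */
    static List<String> bootstrap(List<String> data, SplittableRandom localRNG) {
        List<String> sample = new ArrayList<>(data.size());
        for (int i = 0; i < data.size(); i++) {
            sample.add(data.get(localRNG.nextInt(data.size())));
        }
        return sample;
    }

    /** Trains numMembers members, each on its own bootstrap sample, and returns them. */
    List<String> train(List<String> labels) {
        SplittableRandom localRNG;
        synchronized (this) {
            localRNG = rng.split();   // independent stream for this train call
            trainInvocationCounter++; // record that the shared RNG was advanced once more
        }
        List<String> members = new ArrayList<>(numMembers);
        for (int i = 0; i < numMembers; i++) {
            members.add(trainMajorityClass(bootstrap(labels, localRNG)));
        }
        return members;
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    /** Unweighted majority vote over the members' outputs. */
    static String vote(List<String> members) {
        return trainMajorityClass(members);
    }

    public static void main(String[] args) {
        BaggingTrainSketch trainer = new BaggingTrainSketch(5, 1L);
        List<String> ensemble = trainer.train(List.of("a", "a", "a", "b"));
        System.out.println(ensemble.size() + " members, vote = " + vote(ensemble));
    }
}
```

Splitting once per call, rather than sharing the parent RNG across members, keeps the work inside the loop free of locking while still making each call's randomness a deterministic function of the seed and the invocation count.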
trainSingleModel
protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, SplittableRandom localRNG, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
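Since trainSingleModel receives localRNG, each member's bootstrap sample is presumably drawn from that stream. The minimal sketch below shows such a draw over example indices, together with the out-of-bag indices it leaves behind; the class and method names are hypothetical. For large n, a bootstrap sample contains roughly 1 - 1/e (about 63.2%) of the distinct examples.

```java
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.SplittableRandom;

public class BootstrapSketch {
    /** Draws n indices from [0, n) uniformly with replacement using the member's RNG. */
    static int[] bootstrapIndices(int n, SplittableRandom localRNG) {
        int[] indices = new int[n];
        for (int i = 0; i < n; i++) {
            indices[i] = localRNG.nextInt(n);
        }
        return indices;
    }

    /** Returns the out-of-bag indices: examples never drawn into the bootstrap sample. */
    static List<Integer> outOfBag(int n, int[] bootstrap) {
        BitSet inBag = new BitSet(n);
        for (int idx : bootstrap) {
            inBag.set(idx);
        }
        List<Integer> oob = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (!inBag.get(i)) {
                oob.add(i);
            }
        }
        return oob;
    }

    public static void main(String[] args) {
        int n = 10_000;
        int[] sample = bootstrapIndices(n, new SplittableRandom(1L));
        double inBagFraction = 1.0 - outOfBag(n, sample).size() / (double) n;
        // For large n this is close to 1 - 1/e, roughly 0.632.
        System.out.printf("distinct examples in bag: %.3f%n", inBagFraction);
    }
}
```

Because the indices depend only on the RNG state, the same localRNG seed reproduces exactly the same bootstrap sample.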
getInvocationCount
public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
- The number of train invocations.
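The reason for counting invocations can be illustrated with a small sketch. Assuming each train call advances the shared SplittableRandom by exactly one split, a fresh instance with the same seed reaches the same RNG state by replaying that many splits. The class below mirrors that split-per-call pattern but is not Tribuo's code; `skipTo` is a hypothetical helper.

```java
import java.util.SplittableRandom;

public class InvocationReplaySketch {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    InvocationReplaySketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Simulates one train call: record the invocation and split off a local stream. */
    synchronized SplittableRandom nextLocalRNG() {
        invocationCount++;
        return rng.split();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Hypothetical helper: reset to the seed, then replay `count` splits. */
    synchronized void skipTo(int count) {
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split(); // each split advances the parent RNG exactly as a train call would
        }
        invocationCount = count;
    }

    public static void main(String[] args) {
        InvocationReplaySketch original = new InvocationReplaySketch(7L);
        original.nextLocalRNG();
        original.nextLocalRNG();
        long third = original.nextLocalRNG().nextLong(); // value seen on the third train call

        InvocationReplaySketch replay = new InvocationReplaySketch(7L);
        replay.skipTo(original.getInvocationCount() - 1); // pretend two calls already happened
        System.out.println(replay.nextLocalRNG().nextLong() == third);
    }
}
```

Here `main` prints `true`: the replayed instance produces the same local stream on its third call as the original did.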
getProvenance
public TrainerProvenance getProvenance()