public class BaggingTrainer<T extends Output<T>> extends Object implements Trainer<T>
A bagged ensemble is a set of models, each trained on a bootstrap sample of the original dataset, whose outputs are aggregated by the configured EnsembleCombiner (for classification, typically an unweighted majority vote).
See:
T. Hastie, R. Tibshirani, & J. Friedman. "The Elements of Statistical Learning". Springer, 2001.
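The bootstrap-and-vote scheme described above can be sketched in plain Java. This is a minimal illustration under stated assumptions, not Tribuo's implementation: `BaggingSketch`, `Example`, `bootstrap`, `trainEnsemble` and `predict` are hypothetical names, the inner "trainer" is modelled as a function from a sample to a predictor, and the combiner is hard-wired to an unweighted majority vote instead of a pluggable EnsembleCombiner. Splitting a per-member generator off a seeded `SplittableRandom` is likewise an assumption chosen to match the `rng` field's type.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.function.Function;

/** Hypothetical bagging sketch (not the Tribuo API). X = input type, L = label type. */
public class BaggingSketch {

    /** A labelled example (hypothetical helper type). */
    public record Example<X, L>(X input, L label) {}

    /** Draws data.size() examples uniformly with replacement: a bootstrap sample. */
    public static <X, L> List<Example<X, L>> bootstrap(List<Example<X, L>> data, SplittableRandom rng) {
        int n = data.size();
        List<Example<X, L>> sample = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            sample.add(data.get(rng.nextInt(n)));
        }
        return sample;
    }

    /** Trains numMembers predictors, each on its own bootstrap sample. */
    public static <X, L> List<Function<X, L>> trainEnsemble(
            List<Example<X, L>> data,
            Function<List<Example<X, L>>, Function<X, L>> innerTrainer,
            int numMembers,
            long seed) {
        SplittableRandom rng = new SplittableRandom(seed);
        List<Function<X, L>> members = new ArrayList<>(numMembers);
        for (int m = 0; m < numMembers; m++) {
            // Each member gets an independent RNG stream split from the seeded root.
            SplittableRandom localRNG = rng.split();
            members.add(innerTrainer.apply(bootstrap(data, localRNG)));
        }
        return members;
    }

    /** Unweighted majority vote; on a tie, the label that first reached the top count wins. */
    public static <X, L> L predict(List<Function<X, L>> members, X input) {
        Map<L, Integer> votes = new HashMap<>();
        L best = null;
        int bestCount = 0;
        for (Function<X, L> member : members) {
            L vote = member.apply(input);
            int count = votes.merge(vote, 1, Integer::sum);
            if (count > bestCount) {
                best = vote;
                bestCount = count;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<Example<Double, String>> data = List.of(
                new Example<>(1.0, "low"), new Example<>(2.0, "low"),
                new Example<>(8.0, "high"), new Example<>(9.0, "high"));
        // Deliberately trivial hypothetical learner: threshold at the mean input of its sample.
        Function<List<Example<Double, String>>, Function<Double, String>> learner = sample -> {
            double mean = sample.stream().mapToDouble(Example::input).average().orElse(0.0);
            return x -> x > mean ? "high" : "low";
        };
        List<Function<Double, String>> members = trainEnsemble(data, learner, 5, 42L);
        System.out.println(predict(members, 0.0));  // prints "low"
        System.out.println(predict(members, 10.0)); // prints "high"
    }
}
```

Because each member sees a different resample (on average about 63% of the distinct original examples, since 1 - (1 - 1/n)^n approaches 1 - 1/e), members tend to disagree on borderline inputs, and the vote reduces the variance of the combined prediction.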
Modifier and Type | Field and Description |
---|---|
protected EnsembleCombiner<T> | combiner |
protected Trainer<T> | innerTrainer |
protected int | numMembers |
protected SplittableRandom | rng |
protected long | seed |
protected int | trainInvocationCounter |
Fields inherited from interface Trainer: DEFAULT_SEED
Modifier | Constructor and Description |
---|---|
protected | BaggingTrainer(): For the configuration system. |
public | BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers) |
public | BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed) |
Modifier and Type | Method and Description |
---|---|
protected String | ensembleName() |
int | getInvocationCount(): The number of times this trainer instance has had its train method invoked. |
TrainerProvenance | getProvenance() |
void | postConfig() |
String | toString() |
Model<T> | train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a predictive model using the examples in the given data set. |
protected Model<T> | trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, SplittableRandom localRNG, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) |
@Config(mandatory=true, description="The trainer to use for each ensemble member.") protected Trainer<T> innerTrainer
@Config(mandatory=true, description="The number of ensemble members to train.") protected int numMembers
@Config(mandatory=true, description="The seed for the RNG.") protected long seed
@Config(mandatory=true, description="The combination function to aggregate each ensemble member's outputs.") protected EnsembleCombiner<T> combiner
protected SplittableRandom rng
protected int trainInvocationCounter
protected BaggingTrainer()
public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers)
public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed)
public void postConfig()
Specified by: postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
protected String ensembleName()
public Model<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
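Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.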
protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, SplittableRandom localRNG, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
public int getInvocationCount()
Description copied from interface: Trainer
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by: getInvocationCount in interface Trainer<T extends Output<T>>
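The replicability argument above can be illustrated with a small, self-contained sketch. It assumes, hypothetically, that every train call splits a per-invocation generator off a root `SplittableRandom` seeded from `seed` and increments the counter; `ReplayableTrainer`, `nextLocalRng` and `fastForward` are illustrative names, not part of Tribuo. Knowing the seed and the invocation count is then enough to rebuild the exact random stream the next invocation will see.

```java
import java.util.SplittableRandom;

/** Hypothetical sketch: an invocation counter makes a seeded RNG stream replayable. */
public class ReplayableTrainer {
    private final long seed;
    private SplittableRandom rng;
    private int trainInvocationCounter = 0;

    public ReplayableTrainer(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Called once per (hypothetical) train invocation: hands out an independent RNG and counts the call. */
    public synchronized SplittableRandom nextLocalRng() {
        SplittableRandom localRNG = rng.split();
        trainInvocationCounter++;
        return localRNG;
    }

    public synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    /** Rebuilds the root RNG from the seed and replays `count` splits to restore a recorded state. */
    public synchronized void fastForward(int count) {
        rng = new SplittableRandom(seed);
        trainInvocationCounter = 0;
        for (int i = 0; i < count; i++) {
            nextLocalRng();
        }
    }

    public static void main(String[] args) {
        ReplayableTrainer original = new ReplayableTrainer(12345L);
        original.nextLocalRng();                          // invocation 1
        original.nextLocalRng();                          // invocation 2
        long third = original.nextLocalRng().nextLong();  // invocation 3
        // A fresh instance fast-forwarded past two invocations reproduces the third stream.
        ReplayableTrainer replay = new ReplayableTrainer(12345L);
        replay.fastForward(2);
        long replayed = replay.nextLocalRng().nextLong();
        System.out.println(third == replayed);              // prints "true"
        System.out.println(original.getInvocationCount());  // prints "3"
    }
}
```

The lock around split-and-increment keeps the counter consistent with the number of splits taken from the root generator when train is called concurrently.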
public TrainerProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.