Package org.tribuo
Interface Trainer<T extends Output<T>>
- Type Parameters:
T - the type of the Output in the examples
- All Superinterfaces:
com.oracle.labs.mlrg.olcut.config.Configurable
,com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
- All Known Subinterfaces:
DecisionTreeTrainer<T>
,IncrementalTrainer<T,U>
,SparseTrainer<T>
- All Known Implementing Classes:
AbstractCARTTrainer
,AbstractFMTrainer
,AbstractLinearSGDTrainer
,AbstractSGDTrainer
,AdaBoostTrainer
,BaggingTrainer
,CARTClassificationTrainer
,CARTJointRegressionTrainer
,CARTRegressionTrainer
,CCEnsembleTrainer
,ClassifierChainTrainer
,DummyClassifierTrainer
,DummyRegressionTrainer
,ElasticNetCDTrainer
,ExtraTreesTrainer
,FMClassificationTrainer
,FMMultiLabelTrainer
,FMRegressionTrainer
,HashingTrainer
,HdbscanTrainer
,IndependentMultiLabelTrainer
,KernelSVMTrainer
,KMeansTrainer
,KNNTrainer
,LARSLassoTrainer
,LARSTrainer
,LibLinearAnomalyTrainer
,LibLinearClassificationTrainer
,LibLinearRegressionTrainer
,LibLinearTrainer
,LibSVMAnomalyTrainer
,LibSVMClassificationTrainer
,LibSVMRegressionTrainer
,LibSVMTrainer
,LinearSGDTrainer
,LinearSGDTrainer
,LinearSGDTrainer
,LogisticRegressionTrainer
,MultinomialNaiveBayesTrainer
,RandomForestTrainer
,SkeletalIndependentRegressionSparseTrainer
,SkeletalIndependentRegressionTrainer
,SLMTrainer
,TensorFlowTrainer
,TransformTrainer
,XGBoostClassificationTrainer
,XGBoostRegressionTrainer
,XGBoostTrainer
public interface Trainer<T extends Output<T>>
extends com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
An interface for things that can train predictive models.
-
Field Summary
Modifier and Type | Field | Description
static final long | DEFAULT_SEED | Default seed used to initialise RNGs.
static final int | INCREMENT_INVOCATION_COUNT | When training a model, passing this value will inform the trainer to simply increment the invocation count rather than set a new one.
-
Method Summary
Modifier and Type | Method | Description
int | getInvocationCount() | The number of times this trainer instance has had its train method invoked.
default void | setInvocationCount(int invocationCount) | Set the internal state of the trainer to the provided number of invocations of the train method.
default Model<T> | train(Dataset<T> examples) | Trains a predictive model using the examples in the given data set.
Model<T> | train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) | Trains a predictive model using the examples in the given data set.
default Model<T> | train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount) | Trains a predictive model using the examples in the given data set.
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable
postConfig
Methods inherited from interface com.oracle.labs.mlrg.olcut.provenance.Provenancable
getProvenance
-
Field Details
-
DEFAULT_SEED
static final long DEFAULT_SEED
Default seed used to initialise RNGs.
-
INCREMENT_INVOCATION_COUNT
static final int INCREMENT_INVOCATION_COUNT
When training a model, passing this value will inform the trainer to simply increment the invocation count rather than set a new one.
-
-
Method Details
-
train
default Model<T> train(Dataset<T> examples)
Trains a predictive model using the examples in the given data set.
- Parameters:
examples - the data set containing the examples.
- Returns:
a predictive model that can be used to generate predictions for new examples.
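The single-argument overload can be pictured as a default method that forwards to the two-argument form with empty run provenance. The sketch below is a self-contained illustration, not Tribuo code: SketchTrainer, SketchModel and MajorityTrainer are hypothetical names, the "dataset" is a plain list of labels, and provenance is simplified to a Map<String, String>. That Tribuo's own default method delegates exactly this way is an assumption of the sketch.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified stand-in for a trainer interface (NOT org.tribuo.Trainer).
interface SketchTrainer {
    // Single-argument overload: delegates with an empty provenance map (assumed behaviour).
    default SketchModel train(List<String> examples) {
        return train(examples, Collections.emptyMap());
    }

    // The overload every implementation must provide.
    SketchModel train(List<String> examples, Map<String, String> runProvenance);
}

// Hypothetical model that always predicts a single label.
class SketchModel {
    final String prediction;
    final Map<String, String> runProvenance;

    SketchModel(String prediction, Map<String, String> runProvenance) {
        this.prediction = prediction;
        this.runProvenance = runProvenance;
    }

    String predict() {
        return prediction;
    }
}

// Toy trainer: predicts the most frequent label seen during training
// (ties are broken arbitrarily by HashMap iteration order).
class MajorityTrainer implements SketchTrainer {
    @Override
    public SketchModel train(List<String> examples, Map<String, String> runProvenance) {
        Map<String, Integer> counts = new HashMap<>();
        for (String label : examples) {
            counts.merge(label, 1, Integer::sum);
        }
        String best = null;
        int bestCount = -1;
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        // Keep an immutable copy of the provenance alongside the model.
        return new SketchModel(best, Map.copyOf(runProvenance));
    }
}
```

Calling `new MajorityTrainer().train(labels)` therefore behaves like `train(labels, Collections.emptyMap())`, so implementations only need to supply the two-argument method.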
-
train
Model<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Trains a predictive model using the examples in the given data set.
- Parameters:
examples - the data set containing the examples.
runProvenance - training-run-specific provenance (e.g., fold number).
- Returns:
a predictive model that can be used to generate predictions for new examples.
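As an illustration of run-specific provenance, the following hypothetical sketch trains one model per cross-validation fold and passes the fold index in the provenance map, so each model carries a record of which run produced it. FoldAwareTrainer, FoldModel and CrossValidationSketch are invented for this example, and a Map<String, String> stands in for the OLCUT Provenance map.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical model that remembers the run-specific provenance it was trained with.
class FoldModel {
    final int trainingSize;
    final Map<String, String> runProvenance;

    FoldModel(int trainingSize, Map<String, String> runProvenance) {
        this.trainingSize = trainingSize;
        this.runProvenance = runProvenance;
    }
}

// Hypothetical trainer mirroring the shape of train(examples, runProvenance).
class FoldAwareTrainer {
    FoldModel train(List<Integer> examples, Map<String, String> runProvenance) {
        // A real trainer would fit parameters here; this one only records metadata.
        return new FoldModel(examples.size(), Map.copyOf(runProvenance));
    }
}

class CrossValidationSketch {
    // Trains one model per fold, tagging each run with its fold index.
    static List<FoldModel> run(List<Integer> data, int numFolds) {
        FoldAwareTrainer trainer = new FoldAwareTrainer();
        List<FoldModel> models = new ArrayList<>();
        for (int fold = 0; fold < numFolds; fold++) {
            List<Integer> trainSplit = new ArrayList<>();
            for (int i = 0; i < data.size(); i++) {
                // Hold out the elements whose index falls in the current fold.
                if (i % numFolds != fold) {
                    trainSplit.add(data.get(i));
                }
            }
            models.add(trainer.train(trainSplit, Map.of("fold", Integer.toString(fold))));
        }
        return models;
    }
}
```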
-
train
default Model<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
Trains a predictive model using the examples in the given data set.
- Parameters:
examples - the data set containing the examples.
runProvenance - training-run-specific provenance (e.g., fold number).
invocationCount - the invocation count that the trainer should be set to before training, which in most cases alters the state of the RNG inside this trainer. If the value is set to INCREMENT_INVOCATION_COUNT, the invocation count is not reset; training simply increments it.
- Returns:
a predictive model that can be used to generate predictions for new examples.
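A minimal sketch of how the invocationCount argument might be handled, reduced to the counter logic alone: reset the counter unless the sentinel is passed, then count the call. InvocationCountingTrainer is hypothetical, and its sentinel value of -1 is chosen for this example only; it does not assume the value of Tribuo's INCREMENT_INVOCATION_COUNT.

```java
import java.util.List;

// Hypothetical trainer showing only the invocation-count bookkeeping.
class InvocationCountingTrainer {
    // Sketch-local sentinel (an assumption for this example, not Tribuo's constant).
    static final int INCREMENT_INVOCATION_COUNT = -1;

    private int invocationCount = 0;

    int getInvocationCount() {
        return invocationCount;
    }

    void setInvocationCount(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("invocation count must be non-negative");
        }
        invocationCount = count;
    }

    // Three-argument form reduced to its counter logic. Returns the invocation
    // index this run was based on, in place of a fitted model.
    int train(List<Double> examples, int requestedCount) {
        if (requestedCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(requestedCount); // explicit value: reset the counter first
        }
        int usedCount = invocationCount;
        invocationCount++; // every call counts as one invocation
        return usedCount;
    }
}
```

Passing the sentinel leaves the running count alone (0, then 1, and so on), while passing 10 makes that call run as invocation 10 and leaves the counter at 11 afterwards.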
-
getInvocationCount
int getInvocationCount()
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed, to ensure replicability in the random number stream.
- Returns:
the number of train invocations.
-
setInvocationCount
default void setInvocationCount(int invocationCount)
Set the internal state of the trainer to the provided number of invocations of the train method.
This is used when reproducing a Tribuo-trained model: by simulating invocations of the train method, it restores the RNG to the state it was in when Tribuo trained the original model. This method should ALWAYS be overridden; the default implementation exists purely for compatibility.
In a future major release this default implementation will be removed.
- Parameters:
invocationCount - the number of invocations of the train method to simulate.
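The replay described above can be sketched with java.util.SplittableRandom, under the assumption that a trainer splits one child RNG off a seeded master RNG on each train call. ReplayableTrainer is a hypothetical class, and its train() returns a single random draw in place of a fitted model. Setting the count rebuilds the master RNG from the seed and discards one split per simulated invocation, so the next call sees the same random stream the original run did.

```java
import java.util.SplittableRandom;

// Hypothetical trainer whose randomness comes from a seeded master RNG.
class ReplayableTrainer {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    ReplayableTrainer(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    int getInvocationCount() {
        return invocationCount;
    }

    // Rebuild the master RNG from the seed, then simulate `count` train calls
    // by splitting (and discarding) one child RNG per simulated call.
    void setInvocationCount(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count must be >= 0");
        }
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split();
        }
        invocationCount = count;
    }

    // Each call takes a fresh child RNG and bumps the counter; the returned
    // draw stands in for the randomised parts of training.
    double train() {
        SplittableRandom localRng = rng.split();
        invocationCount++;
        return localRng.nextDouble();
    }
}
```

For example, if an original trainer's third train call produced a model, a fresh trainer built with the same seed and given setInvocationCount(2) reproduces that call's random stream on its next train.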
-