Package org.tribuo

Interface Trainer<T extends Output<T>>

Type Parameters:
T - the type of the Output in the examples
All Superinterfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
All Known Subinterfaces:
DecisionTreeTrainer<T>, IncrementalTrainer<T,U>, SparseTrainer<T>
All Known Implementing Classes:
AbstractCARTTrainer, AbstractFMTrainer, AbstractLinearSGDTrainer, AbstractSGDTrainer, AdaBoostTrainer, BaggingTrainer, CARTClassificationTrainer, CARTJointRegressionTrainer, CARTRegressionTrainer, CCEnsembleTrainer, ClassifierChainTrainer, DummyClassifierTrainer, DummyRegressionTrainer, ElasticNetCDTrainer, ExtraTreesTrainer, FMClassificationTrainer, FMMultiLabelTrainer, FMRegressionTrainer, HashingTrainer, HdbscanTrainer, IndependentMultiLabelTrainer, KernelSVMTrainer, KMeansTrainer, KNNTrainer, LARSLassoTrainer, LARSTrainer, LibLinearAnomalyTrainer, LibLinearClassificationTrainer, LibLinearRegressionTrainer, LibLinearTrainer, LibSVMAnomalyTrainer, LibSVMClassificationTrainer, LibSVMRegressionTrainer, LibSVMTrainer, LinearSGDTrainer, LinearSGDTrainer, LinearSGDTrainer, LogisticRegressionTrainer, MultinomialNaiveBayesTrainer, RandomForestTrainer, SkeletalIndependentRegressionSparseTrainer, SkeletalIndependentRegressionTrainer, SLMTrainer, TensorFlowTrainer, TransformTrainer, XGBoostClassificationTrainer, XGBoostRegressionTrainer, XGBoostTrainer

public interface Trainer<T extends Output<T>> extends com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
An interface for things that can train predictive models.
  • Field Summary

    Fields
    Modifier and Type
    Field
    Description
    static final long
    DEFAULT_SEED
    Default seed used to initialise RNGs.
    static final int
    INCREMENT_INVOCATION_COUNT
    When training a model, passing this value tells the trainer to increment the invocation count rather than set a new one.
  • Method Summary

    Modifier and Type
    Method
    Description
    int
    getInvocationCount()
    The number of times this trainer instance has had its train method invoked.
    default void
    setInvocationCount(int invocationCount)
    Set the internal state of the trainer to the provided number of invocations of the train method.
    default Model<T>
    train(Dataset<T> examples)
    Trains a predictive model using the examples in the given data set.
    Model<T>
    train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
    Trains a predictive model using the examples in the given data set.
    default Model<T>
    train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
    Trains a predictive model using the examples in the given data set.

    Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable

    postConfig

    Methods inherited from interface com.oracle.labs.mlrg.olcut.provenance.Provenancable

    getProvenance
  • Field Details

    • DEFAULT_SEED

      static final long DEFAULT_SEED
      Default seed used to initialise RNGs.
    • INCREMENT_INVOCATION_COUNT

      static final int INCREMENT_INVOCATION_COUNT
      When training a model, passing this value tells the trainer to increment the invocation count rather than set a new one.
  • Method Details

    • train

      default Model<T> train(Dataset<T> examples)
      Trains a predictive model using the examples in the given data set.
      Parameters:
      examples - the data set containing the examples.
      Returns:
      a predictive model that can be used to generate predictions for new examples.
    • train

      Model<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
      Trains a predictive model using the examples in the given data set.
      Parameters:
      examples - the data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      Returns:
      a predictive model that can be used to generate predictions for new examples.
    • train

      default Model<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
      Trains a predictive model using the examples in the given data set.
      Parameters:
      examples - the data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      invocationCount - The invocation counter that the trainer should be set to before training, which in most cases alters the state of the RNG inside this trainer. If the value is INCREMENT_INVOCATION_COUNT, the counter is not set to a new value before training and is simply incremented as usual.
      Returns:
      a predictive model that can be used to generate predictions for new examples.
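
How the three train overloads could relate to one another is sketched below. This is a simplified, hypothetical stand-in and not Tribuo's source: SketchTrainer, CountingTrainer, the String "model" type, the String-valued provenance map, the delegation between overloads, and the sentinel value -1 are all assumptions made for illustration. Only the shape (one abstract two-argument train plus default one- and three-argument overloads) is taken from the signatures above.

```java
import java.util.Collections;
import java.util.List;
import java.util.Map;

final class TrainOverloadSketch {

    /** Hypothetical, simplified trainer interface; a "model" is just a String here. */
    interface SketchTrainer {
        /** Sentinel meaning "leave the count alone and let training increment it". Value assumed. */
        int INCREMENT_INVOCATION_COUNT = -1;

        int getInvocationCount();

        void setInvocationCount(int invocationCount);

        /** The single abstract overload that implementations must provide. */
        String train(List<String> examples, Map<String, String> runProvenance);

        /** Convenience overload: assumed to delegate with empty run provenance. */
        default String train(List<String> examples) {
            return train(examples, Collections.emptyMap());
        }

        /** Sets the counter first unless the sentinel is passed, then trains. */
        default String train(List<String> examples, Map<String, String> runProvenance, int invocationCount) {
            synchronized (this) {
                if (invocationCount != INCREMENT_INVOCATION_COUNT) {
                    setInvocationCount(invocationCount);
                }
                return train(examples, runProvenance);
            }
        }
    }

    /** Toy implementation that only tracks how often train was invoked. */
    static final class CountingTrainer implements SketchTrainer {
        private int invocationCount = 0;

        @Override
        public synchronized int getInvocationCount() {
            return invocationCount;
        }

        @Override
        public synchronized void setInvocationCount(int invocationCount) {
            if (invocationCount < 0) {
                throw new IllegalArgumentException("Invocation count must be non-negative");
            }
            this.invocationCount = invocationCount;
        }

        @Override
        public synchronized String train(List<String> examples, Map<String, String> runProvenance) {
            invocationCount++;
            return "model#" + invocationCount + " from " + examples.size() + " examples " + runProvenance;
        }
    }

    public static void main(String[] args) {
        CountingTrainer trainer = new CountingTrainer();
        List<String> data = List.of("a", "b", "c");

        trainer.train(data);                           // count goes 0 -> 1
        trainer.train(data, Map.of("fold", "1"), 5);   // count set to 5, then 5 -> 6
        trainer.train(data, Map.of("fold", "2"),
                SketchTrainer.INCREMENT_INVOCATION_COUNT); // count left alone, then 6 -> 7

        System.out.println(trainer.getInvocationCount()); // prints 7
    }
}
```

In this sketch an explicit count replaces the trainer's current counter before training, while the sentinel leaves the counter to advance by one per call.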
    • getInvocationCount

      int getInvocationCount()
      The number of times this trainer instance has had its train method invoked.

      This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.

      Returns:
      The number of train invocations.
    • setInvocationCount

      default void setInvocationCount(int invocationCount)
      Set the internal state of the trainer to the provided number of invocations of the train method.

      This is used when reproducing a Tribuo-trained model: simulating invocations of the train method sets the state of the RNG to what it was when Tribuo trained the original model. This method should ALWAYS be overridden; the default implementation exists purely for compatibility.

      In a future major release this default implementation will be removed.

      Parameters:
      invocationCount - the number of invocations of the train method to simulate
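
A minimal sketch of what an overriding setInvocationCount might do, assuming the trainer draws a fresh RNG from a seeded parent on every train call. ToyTrainer, its seed handling, and the use of java.util.SplittableRandom are assumptions for illustration, not a description of any particular Tribuo trainer. The idea is that resetting the RNG and replaying the given number of draws puts a fresh trainer into the same state as the one that produced the original model.

```java
import java.util.SplittableRandom;

final class ReplaySketch {

    /** Hypothetical toy trainer whose "model" is a single random draw. */
    static final class ToyTrainer {
        private final long seed;
        private SplittableRandom rng;
        private int invocationCount = 0;

        ToyTrainer(long seed) {
            this.seed = seed;
            this.rng = new SplittableRandom(seed);
        }

        /** Each call splits off a per-invocation RNG, advancing the parent's state. */
        synchronized long train() {
            SplittableRandom localRng = rng.split();
            invocationCount++;
            return localRng.nextLong();
        }

        synchronized int getInvocationCount() {
            return invocationCount;
        }

        /** Resets the RNG to its seed and simulates the requested number of train calls. */
        synchronized void setInvocationCount(int newCount) {
            if (newCount < 0) {
                throw new IllegalArgumentException("Invocation count must be non-negative");
            }
            rng = new SplittableRandom(seed);
            for (int i = 0; i < newCount; i++) {
                rng.split(); // discard the child; only the parent's state matters
            }
            invocationCount = newCount;
        }
    }

    public static void main(String[] args) {
        long seed = 12345L; // arbitrary seed chosen for this sketch

        // The original run: the model of interest came from the third train call.
        ToyTrainer original = new ToyTrainer(seed);
        original.train();
        original.train();
        int countBeforeThird = original.getInvocationCount(); // 2
        long thirdModel = original.train();

        // Reproduction: a fresh trainer fast-forwarded to the recorded count.
        ToyTrainer fresh = new ToyTrainer(seed);
        fresh.setInvocationCount(countBeforeThird);
        long reproduced = fresh.train();

        System.out.println(thirdModel == reproduced); // prints true
    }
}
```

Recording the invocation count alongside a model is what makes this replay possible: with the seed and the count, the RNG stream can be rebuilt without retraining the earlier models.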