Class AbstractSGDTrainer<T extends Output<T>,U,V extends Model<T>,X extends FeedForwardParameters>

java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<T,U,V,X>
Type Parameters:
T - The output type.
U - The intermediate representation of the labels.
V - The model type.
X - The parameter type.
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>, WeightedExamples
Direct Known Subclasses:
AbstractFMTrainer, AbstractLinearSGDTrainer

public abstract class AbstractSGDTrainer<T extends Output<T>,U,V extends Model<T>,X extends FeedForwardParameters> extends Object implements Trainer<T>, WeightedExamples
A trainer for a model which uses stochastic gradient descent (SGD).

See:

 Bottou L.
 "Large-Scale Machine Learning with Stochastic Gradient Descent"
 Proceedings of COMPSTAT, 2010.
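The basic technique this class is built around can be sketched in plain Java. This is a minimal, self-contained illustration of minibatch SGD for least-squares linear regression with a fixed learning rate. It is not Tribuo's implementation. The class `MinibatchSgdSketch`, the method `fit` and the `learningRate` parameter are hypothetical, and Tribuo delegates the update rule to a `StochasticGradientOptimiser` instead of using a fixed step.

```java
import java.util.SplittableRandom;

/** Hypothetical sketch: plain minibatch SGD for least-squares linear regression. */
class MinibatchSgdSketch {

    /** Returns the learned weights; the last entry is the bias term. */
    static double[] fit(double[][] x, double[] y, int epochs, int minibatchSize,
                        double learningRate, long seed) {
        int n = x.length;
        int d = x[0].length;
        double[] w = new double[d + 1]; // feature weights plus a bias in the last slot
        int[] order = new int[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        SplittableRandom rng = new SplittableRandom(seed);
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Fisher-Yates shuffle of the visiting order at the start of each epoch.
            for (int i = n - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[j];
                order[j] = tmp;
            }
            for (int start = 0; start < n; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, n);
                double[] grad = new double[d + 1];
                for (int k = start; k < end; k++) {
                    int i = order[k];
                    double pred = w[d];
                    for (int f = 0; f < d; f++) {
                        pred += w[f] * x[i][f];
                    }
                    double err = pred - y[i]; // derivative of 0.5 * (pred - y)^2 w.r.t. pred
                    for (int f = 0; f < d; f++) {
                        grad[f] += err * x[i][f];
                    }
                    grad[d] += err;
                }
                int size = end - start;
                for (int f = 0; f <= d; f++) {
                    w[f] -= learningRate * grad[f] / size; // step along the mean minibatch gradient
                }
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // Noise-free data following y = 2x + 1.
        double[][] x = {{0}, {1}, {2}, {3}, {4}};
        double[] y = {1, 3, 5, 7, 9};
        double[] w = fit(x, y, 2000, 2, 0.05, 42L);
        System.out.printf("slope=%.3f bias=%.3f%n", w[0], w[1]);
    }
}
```

Because the toy data is noise-free, every minibatch gradient vanishes at the true solution, so the weights settle close to slope 2 and bias 1.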
 
  • Field Details

    • optimiser

      @Config(description="The gradient optimiser to use.") protected StochasticGradientOptimiser optimiser
    • epochs

      @Config(description="The number of gradient descent epochs.") protected int epochs
    • loggingInterval

      @Config(description="Log values after this many updates.") protected int loggingInterval
    • minibatchSize

      @Config(description="Minibatch size in SGD.") protected int minibatchSize
    • seed

      @Config(description="Seed for the RNG used to shuffle elements.") protected long seed
    • shuffle

      @Config(description="Shuffle the data before each epoch. Only turn off for debugging.") protected boolean shuffle
    • addBias

      protected final boolean addBias
      Should the model add a bias feature to the feature vector?
    • rng

      protected SplittableRandom rng
  • Constructor Details

    • AbstractSGDTrainer

      protected AbstractSGDTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, boolean addBias)
      Constructs an SGD trainer.
      Parameters:
      optimiser - The gradient optimiser to use.
      epochs - The number of epochs (complete passes through the training data).
      loggingInterval - Log the loss after this many iterations. If -1, don't log anything.
      minibatchSize - The number of examples in each minibatch.
      seed - A seed for the random number generator, used to shuffle the examples before each epoch.
      addBias - Should the model add a bias feature to the feature vector?
    • AbstractSGDTrainer

      protected AbstractSGDTrainer(boolean addBias)
      Base constructor called by subclass no-args constructors used by OLCUT.
      Parameters:
      addBias - Should the model add a bias feature to the feature vector?
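The constructor parameters `epochs`, `minibatchSize` and `loggingInterval` together determine how many updates a training run performs and how often it logs. The sketch below assumes that each epoch walks the whole dataset in minibatches of `minibatchSize`, that a trailing partial minibatch still yields one update, and that `loggingInterval` counts updates, with -1 disabling logging as documented. `SgdSchedule` is a hypothetical helper, not part of Tribuo.

```java
/** Hypothetical helper: how many SGD updates and log lines a configuration produces. */
final class SgdSchedule {
    final int updatesPerEpoch;
    final long totalUpdates;
    final long logEvents;

    SgdSchedule(int numExamples, int epochs, int minibatchSize, int loggingInterval) {
        if (numExamples < 1 || epochs < 0 || minibatchSize < 1) {
            throw new IllegalArgumentException("Invalid schedule configuration");
        }
        // Ceiling division: a trailing partial minibatch still produces one update (assumption).
        this.updatesPerEpoch = (numExamples + minibatchSize - 1) / minibatchSize;
        this.totalUpdates = (long) updatesPerEpoch * epochs;
        // A loggingInterval of -1 disables logging; otherwise log once every loggingInterval updates.
        this.logEvents = loggingInterval > 0 ? totalUpdates / loggingInterval : 0L;
    }

    public static void main(String[] args) {
        SgdSchedule s = new SgdSchedule(1000, 5, 32, 50);
        System.out.println(s.updatesPerEpoch + " updates/epoch, "
                + s.totalUpdates + " total, " + s.logEvents + " log lines");
    }
}
```

Under these assumptions, 1000 examples, 5 epochs, minibatches of 32 and a logging interval of 50 give 32 updates per epoch, 160 updates in total and 3 log lines.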
  • Method Details

    • postConfig

      public void postConfig()
      Used by the OLCUT configuration system, and should not be called by external code.
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
    • setShuffle

      public void setShuffle(boolean shuffle)
      Turn on or off shuffling of examples.

      This isn't exposed in the constructor as it defaults to on. This method should only be used for debugging.

      Parameters:
      shuffle - If true shuffle the examples, if false leave them in their current order.
    • train

      public V train(Dataset<T> examples)
      Description copied from interface: Trainer
      Trains a predictive model using the examples in the given data set.
      Specified by:
      train in interface Trainer<T extends Output<T>>
      Parameters:
      examples - the data set containing the examples.
      Returns:
      a predictive model that can be used to generate predictions for new examples.
    • train

      public V train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
      Description copied from interface: Trainer
      Trains a predictive model using the examples in the given data set.
      Specified by:
      train in interface Trainer<T extends Output<T>>
      Parameters:
      examples - the data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      Returns:
      a predictive model that can be used to generate predictions for new examples.
    • train

      public V train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
      Description copied from interface: Trainer
      Trains a predictive model using the examples in the given data set.
      Specified by:
      train in interface Trainer<T extends Output<T>>
      Parameters:
      examples - the data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      invocationCount - The invocation counter that the trainer should be set to before training, which in most cases alters the state of the RNG inside this trainer. If the value is set to Trainer.INCREMENT_INVOCATION_COUNT then the invocation count is not changed.
      Returns:
      a predictive model that can be used to generate predictions for new examples.
    • getInvocationCount

      public int getInvocationCount()
      Description copied from interface: Trainer
      The number of times this trainer instance has had its train method invoked.

      This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.

      Specified by:
      getInvocationCount in interface Trainer<T extends Output<T>>
      Returns:
      The number of train invocations.
    • setInvocationCount

      public void setInvocationCount(int invocationCount)
      Description copied from interface: Trainer
      Set the internal state of the trainer to the provided number of invocations of the train method.

      This is used when reproducing a Tribuo-trained model: the RNG is returned to the state it was in when Tribuo trained the original model by simulating invocations of the train method. This method should ALWAYS be overridden, and the default method is purely for compatibility.

      In a future major release this default implementation will be removed.

      Specified by:
      setInvocationCount in interface Trainer<T extends Output<T>>
      Parameters:
      invocationCount - the number of invocations of the train method to simulate
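The invocation-count contract can be illustrated with `java.util.SplittableRandom`. This sketch assumes a trainer that splits one child RNG off a seeded parent on every train call; replaying that many splits from a freshly seeded parent then restores the RNG state for the next call. `InvocationCountSketch` is hypothetical, and its `INCREMENT_INVOCATION_COUNT` constant is a local sentinel standing in for `Trainer.INCREMENT_INVOCATION_COUNT`; the sketch does not depend on that constant's real value.

```java
import java.util.SplittableRandom;

/**
 * Hypothetical sketch of invocation-count bookkeeping: each train call splits one
 * child RNG off a seeded parent, so replaying N splits restores the parent's state.
 */
class InvocationCountSketch {
    static final int INCREMENT_INVOCATION_COUNT = -1; // local sentinel, not Tribuo's constant

    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    InvocationCountSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Simulates a train call and returns the first draw from its per-call RNG. */
    synchronized long train(int requestedCount) {
        if (requestedCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(requestedCount);
        }
        SplittableRandom localRng = rng.split(); // per-call RNG for shuffling and initialisation
        invocationCount++;
        return localRng.nextLong();
    }

    synchronized long train() {
        return train(INCREMENT_INVOCATION_COUNT);
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Reseeds the parent RNG, then replays the requested number of splits. */
    synchronized void setInvocationCount(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count must be non-negative");
        }
        rng = new SplittableRandom(seed);
        for (invocationCount = 0; invocationCount < count; invocationCount++) {
            rng.split(); // discard the child; only the parent's state change matters
        }
    }

    public static void main(String[] args) {
        InvocationCountSketch original = new InvocationCountSketch(12345L);
        original.train();
        original.train();
        long third = original.train();

        InvocationCountSketch replay = new InvocationCountSketch(12345L);
        long replayed = replay.train(2); // jump straight to the third invocation
        System.out.println(third == replayed);
    }
}
```

Reseeding and replaying, rather than storing RNG state, means only the seed and an integer need to be recorded to reproduce a given training run.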
    • getTarget

      protected abstract U getTarget(ImmutableOutputInfo<T> outputInfo, T output)
      Extracts the appropriate training time representation from the supplied output.
      Parameters:
      outputInfo - The output info to use.
      output - The output to extract.
      Returns:
      The training time representation of the output.
    • getObjective

      protected abstract SGDObjective<U> getObjective()
      Returns the objective used by this trainer.
      Returns:
      The SGDObjective used by this trainer.
    • createModel

      protected abstract V createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<T> outputInfo, X parameters)
      Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
      Parameters:
      name - The model name.
      provenance - The model provenance.
      featureMap - The feature map.
      outputInfo - The output info.
      parameters - The model parameters.
      Returns:
      A new instance of the appropriate subclass of Model.
    • getModelClassName

      protected abstract String getModelClassName()
      Returns the class name of the model that's produced by this trainer.
      Returns:
      The model class name.
    • getName

      protected abstract String getName()
      Returns the default model name.
      Returns:
      The default model name.
    • createParameters

      protected abstract X createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG)
      Constructs the trainable parameters object.
      Parameters:
      numFeatures - The number of input features.
      numOutputs - The number of output dimensions.
      localRNG - The RNG to use for parameter initialisation.
      Returns:
      The trainable parameters.
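The abstract methods above act as hooks for a template method: the base class owns the training loop, and subclasses supply target extraction, the objective, parameter creation and model construction. The sketch below shows that division of labour with simplified, hypothetical types (`SketchSgdTrainer`, `SketchObjective`, `StringRegressionTrainer`). It drops `ImmutableOutputInfo`, the feature map and provenance, uses per-example updates with no shuffling, and is not the Tribuo source.

```java
import java.util.List;
import java.util.SplittableRandom;

/** Hypothetical stand-in for an objective: derivative of the loss w.r.t. the prediction. */
interface SketchObjective<U> {
    double gradientWrtPrediction(double prediction, U target);
}

/** Hypothetical skeleton: the base class owns the SGD loop, subclasses fill in the hooks. */
abstract class SketchSgdTrainer<T, U, V> {
    private final int epochs;
    private final double learningRate;
    private final long seed;

    protected SketchSgdTrainer(int epochs, double learningRate, long seed) {
        this.epochs = epochs;
        this.learningRate = learningRate;
        this.seed = seed;
    }

    protected abstract U getTarget(T output);                                            // cf. getTarget
    protected abstract SketchObjective<U> getObjective();                                 // cf. getObjective
    protected abstract double[] createParameters(int numFeatures, SplittableRandom rng); // cf. createParameters
    protected abstract V createModel(String name, double[] parameters);                  // cf. createModel

    /** Template method: per-example SGD on a linear predictor, no shuffling for brevity. */
    public final V train(List<double[]> features, List<T> outputs) {
        SplittableRandom localRng = new SplittableRandom(seed);
        double[] w = createParameters(features.get(0).length, localRng);
        SketchObjective<U> objective = getObjective();
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int i = 0; i < features.size(); i++) {
                double[] x = features.get(i);
                double pred = 0.0;
                for (int k = 0; k < w.length; k++) {
                    pred += w[k] * x[k];
                }
                // Chain rule for a linear predictor: dLoss/dw_k = dLoss/dPred * x_k.
                double dPred = objective.gradientWrtPrediction(pred, getTarget(outputs.get(i)));
                for (int k = 0; k < w.length; k++) {
                    w[k] -= learningRate * dPred * x[k];
                }
            }
        }
        return createModel("sketch-model", w);
    }
}

/** Toy subclass: outputs arrive as numeric strings, U is Double, the "model" is a weight copy. */
class StringRegressionTrainer extends SketchSgdTrainer<String, Double, double[]> {
    StringRegressionTrainer() {
        super(1000, 0.05, 1L);
    }

    @Override
    protected Double getTarget(String output) {
        return Double.parseDouble(output);
    }

    @Override
    protected SketchObjective<Double> getObjective() {
        return (prediction, target) -> prediction - target; // gradient of 0.5 * (pred - y)^2
    }

    @Override
    protected double[] createParameters(int numFeatures, SplittableRandom rng) {
        return new double[numFeatures]; // zero init; a real trainer might draw from rng here
    }

    @Override
    protected double[] createModel(String name, double[] parameters) {
        return parameters.clone();
    }

    public static void main(String[] args) {
        // The first feature is a constant 1.0 acting as a bias; the data follows y = 1 + 2x.
        List<double[]> xs = List.of(new double[]{1, 0}, new double[]{1, 1}, new double[]{1, 2});
        List<String> ys = List.of("1", "3", "5");
        double[] w = new StringRegressionTrainer().train(xs, ys);
        System.out.printf("bias=%.3f slope=%.3f%n", w[0], w[1]);
    }
}
```

Keeping the loop `final` in the base class means every subclass inherits the same training procedure and can only vary the pieces exposed as hooks.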
    • getProvenance

      public TrainerProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
    • shuffleInPlace

      public static <T> void shuffleInPlace(SGDVector[] features, T[] labels, double[] weights, SplittableRandom rng)
      Shuffles the features, outputs and weights in place.
      Type Parameters:
      T - The output type.
      Parameters:
      features - Feature array.
      labels - Output array.
      weights - Weight array.
      rng - Random number generator.
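A shuffle that keeps features, outputs and weights aligned must apply one random permutation to all three arrays. The sketch below does this with a Fisher-Yates shuffle driven by `SplittableRandom`. It is an assumption about how such a method can be written, not Tribuo's code, and it uses `double[]` rows in place of `SGDVector` so that it runs without Tribuo on the classpath. `ParallelShuffle` is a hypothetical class name.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

/** Hypothetical stand-in for shuffleInPlace: double[] rows replace SGDVector. */
final class ParallelShuffle {
    private ParallelShuffle() {}

    /** Applies one Fisher-Yates permutation to all three arrays so rows stay aligned. */
    static <T> void shuffleInPlace(double[][] features, T[] labels, double[] weights,
                                   SplittableRandom rng) {
        int n = features.length;
        if (labels.length != n || weights.length != n) {
            throw new IllegalArgumentException("Arrays must have the same length");
        }
        for (int i = n - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1); // uniform index in [0, i]
            double[] f = features[i];
            features[i] = features[j];
            features[j] = f;
            T l = labels[i];
            labels[i] = labels[j];
            labels[j] = l;
            double w = weights[i];
            weights[i] = weights[j];
            weights[j] = w;
        }
    }

    public static void main(String[] args) {
        double[][] features = {{0.0}, {1.0}, {2.0}, {3.0}};
        String[] labels = {"a", "b", "c", "d"};
        double[] weights = {0.0, 1.0, 2.0, 3.0};
        shuffleInPlace(features, labels, weights, new SplittableRandom(7L));
        // Each row still carries its own label and weight after shuffling.
        System.out.println(Arrays.deepToString(features) + " "
                + Arrays.toString(labels) + " " + Arrays.toString(weights));
    }
}
```

Given the same seed, the permutation is identical from run to run, which is what allows a seeded trainer to visit examples in a reproducible order.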