Package org.tribuo.common.sgd
Class AbstractSGDTrainer<T extends Output<T>,U,V extends Model<T>,X extends FeedForwardParameters>
java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<T,U,V,X>
- Type Parameters:
T - The output type.
U - The intermediate representation of the labels.
V - The model type.
X - The parameter type.
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>, WeightedExamples
- Direct Known Subclasses:
AbstractFMTrainer, AbstractLinearSGDTrainer
public abstract class AbstractSGDTrainer<T extends Output<T>,U,V extends Model<T>,X extends FeedForwardParameters>
extends Object
implements Trainer<T>, WeightedExamples
A trainer for a model which is fitted using stochastic gradient descent (SGD).
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent" Proceedings of COMPSTAT, 2010.
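The configurable fields of this trainer (epochs, minibatchSize, seed, shuffle, loggingInterval) control a standard minibatch SGD loop. The following is a minimal, Tribuo-free sketch of such a loop, assuming a one-dimensional linear model with squared loss and a fixed learning rate. The class name, method name and learning-rate parameter are hypothetical, and this is not the library's implementation (Tribuo delegates the parameter update to the configured StochasticGradientOptimiser).

```java
import java.util.SplittableRandom;

/**
 * Hypothetical sketch of minibatch SGD for a linear model y ~ w * x + b
 * trained with squared loss. Not Tribuo's implementation.
 */
class MinibatchSgdSketch {

    /** Returns {w, b} after training. A loggingInterval of -1 disables logging. */
    static double[] fit(double[] xs, double[] ys, int epochs, int minibatchSize,
                        double learningRate, long seed, boolean shuffle, int loggingInterval) {
        SplittableRandom rng = new SplittableRandom(seed);
        int n = xs.length;
        int[] order = new int[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        double w = 0.0, b = 0.0;
        int updates = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            if (shuffle) {
                // Fisher-Yates shuffle of the visiting order, driven by the seeded RNG.
                for (int i = n - 1; i > 0; i--) {
                    int j = rng.nextInt(i + 1);
                    int tmp = order[i]; order[i] = order[j]; order[j] = tmp;
                }
            }
            for (int start = 0; start < n; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, n);
                double gradW = 0.0, gradB = 0.0, loss = 0.0;
                for (int k = start; k < end; k++) {
                    int idx = order[k];
                    double err = (w * xs[idx] + b) - ys[idx];
                    loss += 0.5 * err * err;
                    gradW += err * xs[idx];
                    gradB += err;
                }
                int size = end - start;
                // Average the minibatch gradient and take one step.
                w -= learningRate * gradW / size;
                b -= learningRate * gradB / size;
                updates++;
                if (loggingInterval > 0 && updates % loggingInterval == 0) {
                    System.out.printf("update %d, minibatch loss %.4f%n", updates, loss / size);
                }
            }
        }
        return new double[] {w, b};
    }
}
```

Because the visiting order is drawn from a SplittableRandom seeded with seed, two calls with identical arguments return identical parameters.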
-
Field Summary
Modifier and Type | Field
protected final boolean | addBias
protected int | epochs
protected int | loggingInterval
protected int | minibatchSize
protected StochasticGradientOptimiser | optimiser
protected SplittableRandom | rng
protected long | seed
protected boolean | shuffle
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
-
Constructor Summary
Modifier | Constructor | Description
protected | AbstractSGDTrainer(boolean addBias) | Base constructor called by subclass no-args constructors used by OLCUT.
protected | AbstractSGDTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, boolean addBias) | Constructs an SGD trainer.
-
Method Summary
Modifier and Type | Method | Description
protected abstract V | createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<T> outputInfo, X parameters) | Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
protected abstract X | createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG) | Constructs the trainable parameters object.
int | getInvocationCount() | The number of times this trainer instance has had its train method invoked.
protected abstract String | getModelClassName() | Returns the class name of the model that's produced by this trainer.
protected abstract String | getName() | Returns the default model name.
protected abstract SGDObjective<U> | getObjective() | Returns the objective used by this trainer.
protected abstract U | getTarget(ImmutableOutputInfo<T> outputInfo, T output) | Extracts the appropriate training time representation from the supplied output.
void | postConfig() | Used by the OLCUT configuration system, and should not be called by external code.
void | setInvocationCount(int invocationCount) | Set the internal state of the trainer to the provided number of invocations of the train method.
void | setShuffle(boolean shuffle) | Turn on or off shuffling of examples.
static <T> void | shuffleInPlace(SGDVector[] features, T[] labels, double[] weights, SplittableRandom rng) | Shuffles the features, outputs and weights in place.
V | train(Dataset<T> examples) | Trains a predictive model using the examples in the given data set.
V | train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) | Trains a predictive model using the examples in the given data set.
V | train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount) | Trains a predictive model using the examples in the given data set.
-
Field Details
-
optimiser
@Config(description="The gradient optimiser to use.") protected StochasticGradientOptimiser optimiser -
epochs
@Config(description="The number of gradient descent epochs.") protected int epochs -
loggingInterval
@Config(description="Log values after this many updates.") protected int loggingInterval -
minibatchSize
@Config(description="Minibatch size in SGD.") protected int minibatchSize -
seed
@Config(description="Seed for the RNG used to shuffle elements.") protected long seed -
shuffle
@Config(description="Shuffle the data before each epoch. Only turn off for debugging.") protected boolean shuffle -
addBias
protected final boolean addBias -
rng
protected SplittableRandom rng -
-
Constructor Details
-
AbstractSGDTrainer
protected AbstractSGDTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, boolean addBias) Constructs an SGD trainer.
- Parameters:
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, don't log anything.
minibatchSize - The number of examples in each minibatch.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
addBias - Should the model add a bias feature to the feature vector?
-
AbstractSGDTrainer
protected AbstractSGDTrainer(boolean addBias) Base constructor called by subclass no-args constructors used by OLCUT.
- Parameters:
addBias - Should the model add a bias feature to the feature vector?
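A bias feature is conventionally a constant 1.0 appended to every feature vector, so that the weight learned for that slot acts as the intercept. The sketch below shows this idea on plain double[] vectors; the helper names are hypothetical, and the exact mechanism Tribuo uses for addBias is not shown here.

```java
import java.util.Arrays;

class BiasFeatureSketch {
    // Hypothetical helper: returns a copy of the features with a constant 1.0 appended.
    static double[] withBias(double[] features) {
        double[] out = Arrays.copyOf(features, features.length + 1);
        out[features.length] = 1.0;
        return out;
    }

    // The dot product with the bias-augmented vector equals w . x + b,
    // where b is the last entry of weights.
    static double predict(double[] weights, double[] features) {
        double[] x = withBias(features);
        double score = 0.0;
        for (int j = 0; j < x.length; j++) {
            score += weights[j] * x[j];
        }
        return score;
    }
}
```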
-
-
Method Details
-
postConfig
public void postConfig() Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
-
setShuffle
public void setShuffle(boolean shuffle) Turn on or off shuffling of examples.
This isn't exposed in the constructor as it defaults to on. This method should only be used for debugging.
- Parameters:
shuffle - If true, shuffle the examples; if false, leave them in their current order.
-
train
public V train(Dataset<T> examples) Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set. -
train
public V train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set. -
train
public V train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount) Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.
- Specified by:
train in interface Trainer<T extends Output<T>>
- Parameters:
examples - The data set containing the examples.
runProvenance - Training run specific provenance (e.g., fold number).
invocationCount - The invocation counter that the trainer should be set to before training, which in most cases alters the state of the RNG inside this trainer. If the value is set to Trainer.INCREMENT_INVOCATION_COUNT then the invocation count is not changed.
- Returns:
- A predictive model that can be used to generate predictions for new examples.
-
getInvocationCount
public int getInvocationCount() Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
- The number of train invocations.
-
setInvocationCount
public void setInvocationCount(int invocationCount) Description copied from interface: Trainer
Set the internal state of the trainer to the provided number of invocations of the train method. This is used when reproducing a Tribuo-trained model: by simulating invocations of the train method, the RNG is returned to the state it was in when Tribuo trained the original model. This method should ALWAYS be overridden, and the default method is purely for compatibility.
In a future major release this default implementation will be removed.
- Specified by:
setInvocationCount in interface Trainer<T extends Output<T>>
- Parameters:
invocationCount - The number of invocations of the train method to simulate.
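The replay mechanism described above can be sketched as follows. It assumes that each training call consumes exactly one split() of a SplittableRandom seeded from seed, and that a sentinel (here a hypothetical INCREMENT constant standing in for Trainer.INCREMENT_INVOCATION_COUNT, with an assumed value of -1) leaves the count unchanged. All names are hypothetical; this illustrates the technique rather than reproducing the library's code.

```java
import java.util.SplittableRandom;

class InvocationCountSketch {
    // Hypothetical sentinel playing the role of Trainer.INCREMENT_INVOCATION_COUNT (value assumed).
    static final int INCREMENT = -1;

    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    InvocationCountSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    // Each call takes one split from the trainer RNG; the returned long stands in
    // for whatever the per-call RNG would drive (shuffling, parameter initialisation).
    synchronized long train(int requestedCount) {
        if (requestedCount != INCREMENT) {
            setInvocationCount(requestedCount);
        }
        SplittableRandom localRNG = rng.split();
        invocationCount++;
        return localRNG.nextLong();
    }

    // Replays the RNG stream: reset to the seed, then split once per simulated invocation.
    synchronized void setInvocationCount(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("invocationCount must be non-negative");
        }
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split();
        }
        invocationCount = count;
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }
}
```

Replaying to a count of 3 and then training once yields the same per-call random stream as a trainer with the same seed that was genuinely trained four times.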
-
getTarget
protected abstract U getTarget(ImmutableOutputInfo<T> outputInfo, T output) Extracts the appropriate training time representation from the supplied output.
- Parameters:
outputInfo - The output info to use.
output - The output to extract.
- Returns:
- The training time representation of the output.
-
getObjective
protected abstract SGDObjective<U> getObjective() Returns the objective used by this trainer.
- Returns:
- The SGDObjective used by this trainer.
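The objective supplies the loss and its gradient with respect to the model's prediction, which the trainer then turns into parameter updates. As a hedged illustration only, the sketch below uses squared loss on a single scalar prediction and applies the chain rule to obtain a weight gradient for a linear model. It does not use Tribuo's SGDObjective interface, and the class and method names are hypothetical.

```java
class SquaredLossSketch {
    // Hypothetical objective: returns {loss, dLoss/dPrediction} for one scalar output,
    // with loss = 0.5 * (prediction - target)^2.
    static double[] lossAndGradient(double target, double prediction) {
        double diff = prediction - target;
        return new double[] {0.5 * diff * diff, diff};
    }

    // Chain rule for a linear model prediction = w . x:
    // dLoss/dw_j = dLoss/dPrediction * x_j.
    static double[] weightGradient(double[] x, double dLossDPrediction) {
        double[] g = new double[x.length];
        for (int j = 0; j < x.length; j++) {
            g[j] = dLossDPrediction * x[j];
        }
        return g;
    }
}
```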
-
createModel
protected abstract V createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<T> outputInfo, X parameters) Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
- Parameters:
name - The model name.
provenance - The model provenance.
featureMap - The feature map.
outputInfo - The output info.
parameters - The model parameters.
- Returns:
- A new instance of the appropriate subclass of Model.
-
getModelClassName
protected abstract String getModelClassName() Returns the class name of the model that's produced by this trainer.
- Returns:
- The model class name.
-
getName
protected abstract String getName() Returns the default model name.
- Returns:
- The default model name.
-
createParameters
protected abstract X createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG) Constructs the trainable parameters object.
- Parameters:
numFeatures - The number of input features.
numOutputs - The number of output dimensions.
localRNG - The RNG to use for parameter initialisation.
- Returns:
- The trainable parameters.
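The abstract hooks above (getTarget, getObjective, createParameters, createModel) follow a template-method design: the base class owns the training loop and subclasses supply the output-specific pieces. The sketch below mimics that split in plain Java using hypothetical stand-ins (a Map<T, Integer> in place of ImmutableOutputInfo, double[] features in place of SGDVector, and a single step hook in place of the objective plus optimiser), with a perceptron-style multiclass update as the example subclass. It illustrates the pattern, not Tribuo's code.

```java
import java.util.Map;
import java.util.SplittableRandom;

// Hypothetical base class: owns the loop, delegates output-specific work to hooks.
abstract class MiniSgdTrainer<T, U, P> {
    protected final long seed;

    MiniSgdTrainer(long seed) {
        this.seed = seed;
    }

    // Maps a user-facing output to its training-time representation (cf. getTarget).
    protected abstract U getTarget(Map<T, Integer> outputInfo, T output);

    // Builds freshly initialised parameters (cf. createParameters).
    protected abstract P createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG);

    // Applies one update for a single example (stands in for objective + optimiser).
    protected abstract void step(P parameters, double[] features, U target);

    P train(double[][] features, T[] outputs, Map<T, Integer> outputInfo, int epochs) {
        SplittableRandom localRNG = new SplittableRandom(seed);
        P params = createParameters(features[0].length, outputInfo.size(), localRNG);
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int i = 0; i < features.length; i++) {
                step(params, features[i], getTarget(outputInfo, outputs[i]));
            }
        }
        return params;
    }
}

// Example subclass: string labels -> class indices, one weight row per class.
class MiniPerceptronTrainer extends MiniSgdTrainer<String, Integer, double[][]> {
    MiniPerceptronTrainer(long seed) {
        super(seed);
    }

    @Override
    protected Integer getTarget(Map<String, Integer> outputInfo, String output) {
        return outputInfo.get(output);
    }

    @Override
    protected double[][] createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG) {
        double[][] w = new double[numOutputs][numFeatures];
        for (double[] row : w) {
            for (int j = 0; j < numFeatures; j++) {
                row[j] = localRNG.nextDouble(-0.01, 0.01); // small random initialisation
            }
        }
        return w;
    }

    // Returns the index of the highest-scoring class (first one wins on ties).
    static int predict(double[][] w, double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < w.length; k++) {
            double score = 0.0;
            for (int j = 0; j < x.length; j++) {
                score += w[k][j] * x[j];
            }
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return best;
    }

    // Perceptron-style update: on a mistake, move the true class towards x and the predicted class away.
    @Override
    protected void step(double[][] w, double[] x, Integer target) {
        int predicted = predict(w, x);
        if (predicted != target) {
            for (int j = 0; j < x.length; j++) {
                w[target][j] += x[j];
                w[predicted][j] -= x[j];
            }
        }
    }
}
```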
-
getProvenance
public TrainerProvenance getProvenance()
-
shuffleInPlace
public static <T> void shuffleInPlace(SGDVector[] features, T[] labels, double[] weights, SplittableRandom rng) Shuffles the features, outputs and weights in place.
- Type Parameters:
T - The output type.
- Parameters:
features - Feature array.
labels - Output array.
weights - Weight array.
rng - Random number generator.
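Keeping three parallel arrays aligned while shuffling only requires applying the same swap to each of them. A sketch of that idea as a Fisher-Yates shuffle follows, assuming double[][] rows in place of SGDVector[] so that it runs without Tribuo. The class name is hypothetical, and the library's own loop may differ in detail.

```java
import java.util.SplittableRandom;

class ParallelShuffleSketch {
    // Fisher-Yates shuffle that applies identical swaps to all three arrays,
    // so features[i], labels[i] and weights[i] always describe the same example.
    static <T> void shuffleInPlace(double[][] features, T[] labels, double[] weights, SplittableRandom rng) {
        int size = features.length;
        for (int i = size - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1); // uniform in [0, i]
            double[] tmpF = features[i]; features[i] = features[j]; features[j] = tmpF;
            T tmpL = labels[i]; labels[i] = labels[j]; labels[j] = tmpL;
            double tmpW = weights[i]; weights[i] = weights[j]; weights[j] = tmpW;
        }
    }
}
```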
-