public abstract class AbstractSGDTrainer<T extends Output<T>,U,V extends Model<T>,X extends FeedForwardParameters> extends Object implements Trainer<T>, WeightedExamples
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent" Proceedings of COMPSTAT, 2010.
Modifier and Type | Field and Description |
---|---|
protected boolean | addBias |
protected int | epochs |
protected int | loggingInterval |
protected int | minibatchSize |
protected StochasticGradientOptimiser | optimiser |
protected SplittableRandom | rng |
protected long | seed |
protected boolean | shuffle |
Fields inherited from interface Trainer: DEFAULT_SEED
Modifier | Constructor and Description |
---|---|
protected | AbstractSGDTrainer(boolean addBias) Base constructor called by subclass no-args constructors used by OLCUT. |
protected | AbstractSGDTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, boolean addBias) Constructs an SGD trainer. |
Modifier and Type | Method and Description |
---|---|
protected abstract V | createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<T> outputInfo, X parameters) Creates the appropriate model subclass for this subclass of AbstractSGDTrainer. |
protected abstract X | createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG) Constructs the trainable parameters object. |
int | getInvocationCount() The number of times this trainer instance has had its train method invoked. |
protected abstract String | getModelClassName() Returns the class name of the model that's produced by this trainer. |
protected abstract String | getName() Returns the default model name. |
protected abstract SGDObjective<U> | getObjective() Returns the objective used by this trainer. |
TrainerProvenance | getProvenance() |
protected abstract U | getTarget(ImmutableOutputInfo<T> outputInfo, T output) Extracts the appropriate training time representation from the supplied output. |
void | postConfig() Used by the OLCUT configuration system, and should not be called by external code. |
void | setShuffle(boolean shuffle) Turn on or off shuffling of examples. |
static <T> void | shuffleInPlace(SGDVector[] features, T[] labels, double[] weights, SplittableRandom rng) Shuffles the features, outputs and weights in place. |
V | train(Dataset<T> examples) Trains a predictive model using the examples in the given data set. |
V | train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) Trains a predictive model using the examples in the given data set. |
@Config(description="The gradient optimiser to use.") protected StochasticGradientOptimiser optimiser
@Config(description="The number of gradient descent epochs.") protected int epochs
@Config(description="Log values after this many updates.") protected int loggingInterval
@Config(description="Minibatch size in SGD.") protected int minibatchSize
@Config(description="Seed for the RNG used to shuffle elements.") protected long seed
@Config(description="Shuffle the data before each epoch. Only turn off for debugging.") protected boolean shuffle
protected final boolean addBias
protected SplittableRandom rng
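As a rough illustration of how the configurable fields above interact, the following sketch runs minibatch SGD on a one-dimensional least-squares problem using only the JDK. It is not this class's implementation: the class name `MinibatchSgdSketch`, the `fit` method, and the fixed `learningRate` (standing in for a `StochasticGradientOptimiser`) are all hypothetical, and the sketch assumes that a `loggingInterval` of -1 disables logging, as the constructor documentation states.

```java
import java.util.SplittableRandom;

/** Hypothetical stand-alone sketch; none of these names belong to the documented API. */
final class MinibatchSgdSketch {

    /** Fits y = w*x + b by minibatch SGD on squared loss and returns {w, b}. */
    static double[] fit(double[] xs, double[] ys, int epochs, int minibatchSize,
                        int loggingInterval, long seed, boolean shuffle,
                        double learningRate) {
        SplittableRandom rng = new SplittableRandom(seed);
        int n = xs.length;
        int[] order = new int[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        double w = 0.0;
        double b = 0.0;
        int updates = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            if (shuffle) {
                // Fisher-Yates shuffle of the visiting order before each epoch.
                for (int i = n - 1; i > 0; i--) {
                    int j = rng.nextInt(i + 1);
                    int tmp = order[i];
                    order[i] = order[j];
                    order[j] = tmp;
                }
            }
            for (int start = 0; start < n; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, n);
                double gradW = 0.0;
                double gradB = 0.0;
                double loss = 0.0;
                for (int k = start; k < end; k++) {
                    int idx = order[k];
                    double err = w * xs[idx] + b - ys[idx];
                    loss += 0.5 * err * err;
                    gradW += err * xs[idx];
                    gradB += err;
                }
                int batch = end - start;
                // One parameter update per minibatch, using the mean gradient.
                w -= learningRate * gradW / batch;
                b -= learningRate * gradB / batch;
                updates++;
                if (loggingInterval != -1 && updates % loggingInterval == 0) {
                    System.out.println("update " + updates + ", mean batch loss " + (loss / batch));
                }
            }
        }
        return new double[] {w, b};
    }
}
```

Calling `fit` twice with the same seed visits the examples in the same order, which is why shuffling stays reproducible while still being on by default.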
protected AbstractSGDTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, boolean addBias)
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1 don't log anything.
minibatchSize - The size of any minibatches.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
addBias - Should the model add a bias feature to the feature vector?
protected AbstractSGDTrainer(boolean addBias)
addBias - Should the model add a bias feature to the feature vector?
public void postConfig()
Specified by: postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
public void setShuffle(boolean shuffle)
This isn't exposed in the constructor as it defaults to on. This method should only be used for debugging.
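shuffle - If true, shuffle the examples; if false, leave them in their current order.
public V train(Dataset<T> examples)
Description copied from interface Trainer: Trains a predictive model using the examples in the given data set.
public V train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface Trainer: Trains a predictive model using the examples in the given data set.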
public int getInvocationCount()
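Description copied from interface Trainer: The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by: getInvocationCount in interface Trainer<T extends Output<T>>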
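To see why an invocation count helps with replicability, consider the sketch below. It assumes, without confirmation from this page, that each training call draws a fresh child generator from the trainer's `SplittableRandom` via `split()`. Under that assumption the seed plus the invocation count is enough to recover the random stream any given call used. The class `InvocationCountSketch` and its methods are hypothetical and use only the JDK.

```java
import java.util.SplittableRandom;

/** Hypothetical sketch of seed plus invocation count bookkeeping. */
final class InvocationCountSketch {
    private final SplittableRandom rng;
    private int invocationCount = 0;

    InvocationCountSketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for a train call: splits off a per-call RNG and bumps the counter. */
    synchronized long trainOnce() {
        SplittableRandom localRNG = rng.split();
        invocationCount++;
        // A real trainer would shuffle and initialise parameters with localRNG;
        // here a single draw represents that stream.
        return localRNG.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Recomputes the draw made by the given 1-based invocation of trainOnce. */
    static long replay(long seed, int invocation) {
        if (invocation < 1) {
            throw new IllegalArgumentException("invocation must be at least 1");
        }
        SplittableRandom rng = new SplittableRandom(seed);
        SplittableRandom local = rng.split();
        for (int i = 1; i < invocation; i++) {
            local = rng.split();
        }
        return local.nextLong();
    }
}
```

In this sketch a model trained on the third call can be reproduced later by building a fresh trainer with the same seed and advancing it to the third invocation.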
protected abstract U getTarget(ImmutableOutputInfo<T> outputInfo, T output)
outputInfo - The output info to use.
output - The output to extract.
protected abstract SGDObjective<U> getObjective()
protected abstract V createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<T> outputInfo, X parameters)
name - The model name.
provenance - The model provenance.
featureMap - The feature map.
outputInfo - The output info.
parameters - The model parameters.
Returns: The appropriate subclass of Model.
protected abstract String getModelClassName()
protected abstract String getName()
protected abstract X createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG)
numFeatures - The number of input features.
numOutputs - The number of output dimensions.
localRNG - The RNG to use for parameter initialisation.
public TrainerProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
public static <T> void shuffleInPlace(SGDVector[] features, T[] labels, double[] weights, SplittableRandom rng)
T - The output type.
features - Feature array.
labels - Output array.
weights - Weight array.
rng - Random number generator.
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.
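Shuffling features, outputs and weights "in place" only works if all three arrays receive the same permutation, so that example i keeps its own label and weight. A minimal sketch of that idea follows, using a Fisher-Yates shuffle; it is an assumption about the general technique, not the code behind `shuffleInPlace`. The class `ParallelShuffleSketch` and the method `shuffleParallel` are hypothetical, and plain `double[]` rows stand in for `SGDVector`.

```java
import java.util.SplittableRandom;

/** Hypothetical sketch of shuffling three parallel arrays with one permutation. */
final class ParallelShuffleSketch {

    static <T> void shuffleParallel(double[][] features, T[] labels,
                                    double[] weights, SplittableRandom rng) {
        if (features.length != labels.length || features.length != weights.length) {
            throw new IllegalArgumentException("arrays must have the same length");
        }
        // Fisher-Yates: walk backwards, swapping position i with a random j <= i.
        for (int i = features.length - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);

            double[] tmpFeature = features[i];
            features[i] = features[j];
            features[j] = tmpFeature;

            T tmpLabel = labels[i];
            labels[i] = labels[j];
            labels[j] = tmpLabel;

            double tmpWeight = weights[i];
            weights[i] = weights[j];
            weights[j] = tmpWeight;
        }
    }
}
```

Because the swap index `j` is drawn once and applied to all three arrays, no extra index array or copy of the data is needed.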