public class LinearSGDTrainer extends Object implements Trainer<Label>, WeightedExamples
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent" Proceedings of COMPSTAT, 2010.
DEFAULT_SEED
Constructor and Description |
---|
LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed): Constructs an SGD trainer for a linear model. |
LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed): Sets the minibatch size to 1. |
LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed): Sets the minibatch size to 1 and the logging interval to 1000. |
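The defaulting described for the shorter constructors (minibatch size 1, logging interval 1000) follows the usual telescoping-constructor pattern. A minimal sketch of that pattern is given here; `SgdSettings` is a hypothetical stand-in written for illustration and is not a Tribuo class, and the objective and optimiser arguments are left out to keep it self-contained.

```java
// Hypothetical stand-in for the trainer's numeric settings; not part of Tribuo.
final class SgdSettings {
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;

    // Full form: every value is supplied by the caller.
    SgdSettings(int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
    }

    // Mirrors the five-argument trainer constructor: minibatch size defaults to 1.
    SgdSettings(int epochs, int loggingInterval, long seed) {
        this(epochs, loggingInterval, 1, seed);
    }

    // Mirrors the four-argument trainer constructor: logging interval defaults to 1000.
    SgdSettings(int epochs, long seed) {
        this(epochs, 1000, seed);
    }
}
```

Each shorter form delegates to the next longer one, so the defaults are stated in exactly one place.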
Modifier and Type | Method and Description |
---|---|
int | getInvocationCount(): The number of times this trainer instance has had its train method invoked. |
TrainerProvenance | getProvenance() |
void | postConfig() |
void | setShuffle(boolean shuffle): Turn on or off shuffling of examples. |
String | toString() |
Model<Label> | train(Dataset<Label> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a predictive model using the examples in the given data set. |
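To make the roles of epochs, minibatch size, seed and shuffling concrete, here is a minimal sketch of minibatch SGD for a binary logistic regression model. It is an illustration under stated assumptions, not Tribuo's implementation: `MiniSgdSketch`, its method names, the fixed learning rate and the dense `double[][]` input are all hypothetical, and loss logging is omitted.

```java
import java.util.Random;

// Hypothetical illustration class; not part of Tribuo.
final class MiniSgdSketch {

    /** Trains binary logistic regression weights (the last entry is the bias) with minibatch SGD. */
    static double[] train(double[][] x, int[] y, int epochs, int minibatchSize,
                          double learningRate, boolean shuffle, long seed) {
        int n = x.length;
        int d = x[0].length;
        double[] w = new double[d + 1];
        int[] order = new int[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Random rng = new Random(seed);

        for (int epoch = 0; epoch < epochs; epoch++) {
            if (shuffle) {
                // Fisher-Yates shuffle of the visiting order before each epoch.
                for (int i = n - 1; i > 0; i--) {
                    int j = rng.nextInt(i + 1);
                    int tmp = order[i];
                    order[i] = order[j];
                    order[j] = tmp;
                }
            }
            for (int start = 0; start < n; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, n);
                double[] grad = new double[d + 1];
                for (int k = start; k < end; k++) {
                    int idx = order[k];
                    // Gradient of the log loss with respect to the score.
                    double err = sigmoid(score(w, x[idx])) - y[idx];
                    for (int f = 0; f < d; f++) grad[f] += err * x[idx][f];
                    grad[d] += err;
                }
                // Average the gradient over the minibatch, then take one step.
                double scale = learningRate / (end - start);
                for (int f = 0; f <= d; f++) w[f] -= scale * grad[f];
            }
        }
        return w;
    }

    /** Linear score: dot product of weights and features plus the bias. */
    static double score(double[] w, double[] features) {
        double s = w[w.length - 1];
        for (int f = 0; f < features.length; f++) s += w[f] * features[f];
        return s;
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }
}
```

Because the visiting order depends only on the seed, two runs with the same seed and data produce identical weights in this sketch.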
public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1 don't log anything.
minibatchSize - The size of any minibatches.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1 don't log anything.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed)
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
public void postConfig()
Specified by: postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
public void setShuffle(boolean shuffle)
Turn on or off shuffling of examples. This isn't exposed in the constructor as it defaults to on. This method should only be used for debugging.
shuffle - If true, shuffle the examples; if false, leave them in their current order.
public Model<Label> train(Dataset<Label> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
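Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.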
public int getInvocationCount()
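Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This count records how many times the trainer's RNG has been accessed, so that the random number stream can be replicated.
Specified by: getInvocationCount in interface Trainer<Label>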
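The link between the invocation count and replicability can be sketched as follows. This is a hedged illustration that assumes each train call takes one fresh split from the trainer's RNG; `ReplayableTrainerSketch` and its `restore` method are hypothetical names, not Tribuo API. Under that assumption, a new instance built from the same seed and advanced by the recorded count reproduces the random stream seen by the next train call.

```java
import java.util.SplittableRandom;

// Hypothetical illustration class; not part of Tribuo.
final class ReplayableTrainerSketch {
    private final SplittableRandom rng;
    private int invocationCount = 0;

    ReplayableTrainerSketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train: takes a per-call RNG and returns its first draw. */
    synchronized long train() {
        // Assumption: one split of the trainer's RNG per invocation.
        SplittableRandom localRng = rng.split();
        invocationCount++;
        return localRng.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Rebuilds a trainer whose RNG is in the state reached after the given number of train calls. */
    static ReplayableTrainerSketch restore(long seed, int invocationCount) {
        ReplayableTrainerSketch trainer = new ReplayableTrainerSketch(seed);
        for (int i = 0; i < invocationCount; i++) {
            trainer.train();
        }
        return trainer;
    }
}
```

Recording only the seed and the count is enough here because the split sequence of a seeded `SplittableRandom` is deterministic.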
public TrainerProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.