Class SGD
java.lang.Object
org.tribuo.math.optimisers.SGD
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, StochasticGradientOptimiser
An implementation of single learning rate SGD, with optional momentum.
Has factory methods to generate constant learning rate, linear decay and sqrt decay variants.
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent", Proceedings of COMPSTAT, 2010.
and for the momentum implementation:
Shallue et al, "Measuring the Effects of Data Parallelism on Neural Network Training", 2018, arXiv:1811.03600.
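For example, a minimal sketch of constructing the variants via the factory methods documented below; the learning rate values are illustrative, not tuned recommendations:

    import org.tribuo.math.StochasticGradientOptimiser;
    import org.tribuo.math.optimisers.SGD;

    // Constant learning rate: lr_t = 0.01 for every iteration.
    StochasticGradientOptimiser simple = SGD.getSimpleSGD(0.01);

    // Linear decay: lr_t = 0.1 / t.
    StochasticGradientOptimiser linear = SGD.getLinearDecaySGD(0.1);

    // Sqrt decay: lr_t = 0.1 / sqrt(t).
    StochasticGradientOptimiser sqrtDecay = SGD.getSqrtDecaySGD(0.1);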
-
Nested Class Summary
Nested Classes:
  static enum SGD.Momentum
      Momentum types.
-
Field Summary
Fields:
  protected double initialLearningRate
      Initial learning rate.
  protected int iteration
  protected double rho
      Momentum scaling factor.
  protected SGD.Momentum useMomentum
-
Constructor Summary
Constructors:
  protected SGD()
      For OLCUT.
-
Method Summary
Methods:
  static SGD getLinearDecaySGD(double learningRate)
      Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate.
  static SGD getLinearDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
      Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum.
  com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
  static SGD getSimpleSGD(double learningRate)
      Generates an SGD optimiser with a constant learning rate set to learningRate.
  static SGD getSimpleSGD(double learningRate, double rho, SGD.Momentum momentumType)
      Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
  static SGD getSqrtDecaySGD(double learningRate)
      Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate.
  static SGD getSqrtDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
      Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum.
  void initialise(Parameters parameters)
      Initialises the gradient optimiser.
  abstract double learningRate()
      Override to provide a function which calculates the learning rate.
  void reset()
      Resets the optimiser so it's ready to optimise a new Parameters.
  protected abstract String sgdType()
      Override to specify the kind of SGD.
  Tensor[] step(Tensor[] updates, double weight)
      Take a Tensor array of gradients and transform them according to the current weight and learning rates.
  String toString()

Methods inherited from class java.lang.Object:
  clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable:
  postConfig

Methods inherited from interface org.tribuo.math.StochasticGradientOptimiser:
  copy, finalise
-
Field Details
-
initialLearningRate
@Config(mandatory=true, description="Initial learning rate.") protected double initialLearningRate
-
useMomentum
protected SGD.Momentum useMomentum
-
rho
@Config(description="Momentum scaling factor.") protected double rho
-
iteration
protected int iteration
-
-
Constructor Details
-
SGD
protected SGD()
For OLCUT.
-
-
Method Details
-
initialise
public void initialise(Parameters parameters)
Description copied from interface: StochasticGradientOptimiser
Initialises the gradient optimiser. Configures any learning rate parameters.
- Specified by:
initialise in interface StochasticGradientOptimiser
- Parameters:
parameters - The parameters to optimise.
-
step
public Tensor[] step(Tensor[] updates, double weight)
Description copied from interface: StochasticGradientOptimiser
Take a Tensor array of gradients and transform them according to the current weight and learning rates. Can return the same Tensor array or a new one.
- Specified by:
step in interface StochasticGradientOptimiser
- Parameters:
updates - An array of gradients.
weight - The weight for the current gradients.
- Returns:
A Tensor array of gradients.
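A hedged sketch of the optimiser lifecycle (initialise, repeated step calls, then reset); computeGradients is a hypothetical stand-in for the model-specific gradient computation, not a Tribuo API, while Parameters.update and Example.getWeight are assumed from the wider Tribuo API:

    import org.tribuo.Example;
    import org.tribuo.classification.Label;
    import org.tribuo.math.Parameters;
    import org.tribuo.math.StochasticGradientOptimiser;
    import org.tribuo.math.la.Tensor;
    import org.tribuo.math.optimisers.SGD;

    static void fit(Parameters parameters, Iterable<Example<Label>> data, int epochs) {
        StochasticGradientOptimiser optimiser = SGD.getSimpleSGD(0.1);
        optimiser.initialise(parameters); // configures learning rate and momentum state
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (Example<Label> example : data) {
                Tensor[] gradients = computeGradients(parameters, example); // hypothetical helper
                // step scales the raw gradients by the current learning rate (and momentum).
                Tensor[] updates = optimiser.step(gradients, example.getWeight());
                parameters.update(updates);
            }
        }
        optimiser.reset(); // ready to optimise a fresh Parameters
    }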
-
learningRate
public abstract double learningRate()
Override to provide a function which calculates the learning rate. The only available information is the iteration count.
- Returns:
The current learning rate.
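As a hedged sketch of a concrete subclass, here is a hypothetical exponential decay variant; it assumes the protected fields documented above are the intended extension hooks and that copy() (from StochasticGradientOptimiser) is left to concrete subclasses:

    import org.tribuo.math.optimisers.SGD;

    /** Hypothetical SGD variant with an exponentially decaying learning rate. */
    public class ExpDecaySGD extends SGD {
        private final double decay;

        public ExpDecaySGD(double learningRate, double decay) {
            this.initialLearningRate = learningRate; // protected field inherited from SGD
            this.decay = decay;
        }

        @Override
        public double learningRate() {
            // iteration is the protected step counter maintained by SGD.
            return initialLearningRate * Math.pow(decay, iteration);
        }

        @Override
        protected String sgdType() {
            return "ExponentialDecay";
        }

        @Override
        public ExpDecaySGD copy() {
            return new ExpDecaySGD(initialLearningRate, decay);
        }
    }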
-
sgdType
protected abstract String sgdType()
Override to specify the kind of SGD.
- Returns:
A string representing the SGD type.
-
toString
public String toString()
- Overrides:
toString in class java.lang.Object
-
reset
public void reset()
Description copied from interface: StochasticGradientOptimiser
Resets the optimiser so it's ready to optimise a new Parameters.
- Specified by:
reset in interface StochasticGradientOptimiser
-
getProvenance
public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
-
getSimpleSGD
public static SGD getSimpleSGD(double learningRate)
Generates an SGD optimiser with a constant learning rate set to learningRate.
- Parameters:
learningRate - The learning rate.
- Returns:
A constant learning rate SGD.
-
getSimpleSGD
public static SGD getSimpleSGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
- Parameters:
learningRate - The learning rate.
rho - The momentum drag constant.
momentumType - Momentum type.
- Returns:
A constant learning rate SGD with momentum.
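For example (hedged: STANDARD and NESTEROV are assumed to be the constants exposed by the SGD.Momentum enum):

    // Constant learning rate 0.01 with standard (heavy-ball) momentum, rho = 0.9.
    SGD heavyBall = SGD.getSimpleSGD(0.01, 0.9, SGD.Momentum.STANDARD);

    // The same schedule with Nesterov momentum.
    SGD nesterov = SGD.getSimpleSGD(0.01, 0.9, SGD.Momentum.NESTEROV);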
-
getLinearDecaySGD
public static SGD getLinearDecaySGD(double learningRate)
Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate. The learning rate = initialLearningRate / iteration.
- Parameters:
learningRate - The learning rate.
- Returns:
A linear decay SGD.
-
getLinearDecaySGD
public static SGD getLinearDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum. The learning rate = initialLearningRate / iteration.
- Parameters:
learningRate - The learning rate.
rho - The momentum drag constant.
momentumType - Momentum type.
- Returns:
A linear decay SGD with momentum.
-
getSqrtDecaySGD
public static SGD getSqrtDecaySGD(double learningRate)
Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate. The learning rate = initialLearningRate / sqrt(iteration).
- Parameters:
learningRate - The learning rate.
- Returns:
A sqrt decay SGD.
-
getSqrtDecaySGD
public static SGD getSqrtDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum. The learning rate = initialLearningRate / sqrt(iteration).
- Parameters:
learningRate - The learning rate.
rho - The momentum drag constant.
momentumType - Momentum type.
- Returns:
A sqrt decay SGD with momentum.
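To make the three schedules concrete, a small illustration of the formulas stated above, assuming the iteration counter starts at 1:

    double lr0 = 1.0; // initial learning rate
    for (int t = 1; t <= 4; t++) {
        System.out.printf("t=%d  constant=%.3f  linear=%.3f  sqrt=%.3f%n",
                t, lr0, lr0 / t, lr0 / Math.sqrt(t));
    }
    // t=1  constant=1.000  linear=1.000  sqrt=1.000
    // t=2  constant=1.000  linear=0.500  sqrt=0.707
    // t=3  constant=1.000  linear=0.333  sqrt=0.577
    // t=4  constant=1.000  linear=0.250  sqrt=0.500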
-