Package org.tribuo.math.optimisers
Class SGD
java.lang.Object
org.tribuo.math.optimisers.SGD
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, StochasticGradientOptimiser
An implementation of single learning rate SGD and optionally momentum.
Has factory methods to generate constant learning rate, linear decay and sqrt decay variants.
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent", Proceedings of COMPSTAT, 2010.
and for the momentum implementation:
Shallue et al., "Measuring the Effects of Data Parallelism on Neural Network Training", 2018, arXiv:1811.03600.
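The single learning rate update with optional momentum can be sketched on plain double[] parameters. This is a minimal sketch, not Tribuo's implementation: the MomentumSketch class, its Kind enum and the update formulation are assumptions. It uses a common framework-style formulation in which the velocity is v = rho * v + g, standard momentum then applies theta -= eta * v, and Nesterov momentum applies theta -= eta * (rho * v + g).

```java
/** Hypothetical sketch of single learning rate SGD with optional momentum (assumed formulation, not Tribuo's code). */
final class MomentumSketch {
    /** Hypothetical momentum kinds for this sketch only. */
    enum Kind { NONE, STANDARD, NESTEROV }

    private final double learningRate; // eta, held constant in this sketch
    private final double rho;          // momentum scaling factor
    private final Kind kind;
    private final double[] velocity;   // one velocity entry per parameter, starts at zero

    MomentumSketch(double learningRate, double rho, Kind kind, int numParams) {
        this.learningRate = learningRate;
        this.rho = rho;
        this.kind = kind;
        this.velocity = new double[numParams];
    }

    /** Applies one gradient descent step to params in place, given the loss gradient grad. */
    void update(double[] params, double[] grad) {
        for (int i = 0; i < params.length; i++) {
            switch (kind) {
                case NONE:
                    // Plain SGD: theta -= eta * g
                    params[i] -= learningRate * grad[i];
                    break;
                case STANDARD:
                    // Heavy-ball momentum: v = rho * v + g, then theta -= eta * v
                    velocity[i] = rho * velocity[i] + grad[i];
                    params[i] -= learningRate * velocity[i];
                    break;
                case NESTEROV:
                    // Nesterov momentum: v = rho * v + g, then theta -= eta * (rho * v + g)
                    velocity[i] = rho * velocity[i] + grad[i];
                    params[i] -= learningRate * (rho * velocity[i] + grad[i]);
                    break;
            }
        }
    }
}
```

On f(x) = x^2 / 2 (gradient x) starting from x = 1 with learning rate 0.1, one plain step gives 0.9, while two standard momentum steps with rho = 0.9 give 0.72. With rho = 0 both momentum kinds reduce to plain SGD.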
-
Nested Class Summary
-
Field Summary
protected double initialLearningRate
protected int iteration
protected double rho
protected SGD.Momentum useMomentum
-
Constructor Summary
-
Method Summary
static SGD getLinearDecaySGD(double learningRate)
    Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate.
static SGD getLinearDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
    Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum.
com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
static SGD getSimpleSGD(double learningRate)
    Generates an SGD optimiser with a constant learning rate set to learningRate.
static SGD getSimpleSGD(double learningRate, double rho, SGD.Momentum momentumType)
    Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
static SGD getSqrtDecaySGD(double learningRate)
    Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate.
static SGD getSqrtDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
    Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum.
void initialise(Parameters parameters)
    Initialises the gradient optimiser.
abstract double learningRate()
    Override to provide a function which calculates the learning rate.
void reset()
    Resets the optimiser so it's ready to optimise a new Parameters.
protected abstract String sgdType()
    Override to specify the kind of SGD.
Tensor[] step(Tensor[] updates, double weight)
    Take a Tensor array of gradients and transform them according to the current weight and learning rates.
String toString()
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable
postConfig
Methods inherited from interface org.tribuo.math.StochasticGradientOptimiser
copy, finalise
-
Field Details
-
initialLearningRate
@Config(mandatory=true, description="Initial learning rate.") protected double initialLearningRate
-
useMomentum
protected SGD.Momentum useMomentum
-
rho
@Config(description="Momentum scaling factor.") protected double rho
-
iteration
protected int iteration
-
-
Constructor Details
-
SGD
protected SGD()
For OLCUT: the no-argument constructor used when the optimiser is created through configuration.
-
-
Method Details
-
initialise
public void initialise(Parameters parameters)
Description copied from interface: StochasticGradientOptimiser
Initialises the gradient optimiser.
Configures any learning rate parameters.
- Specified by:
initialise in interface StochasticGradientOptimiser
- Parameters:
parameters - The parameters to optimise.
-
step
public Tensor[] step(Tensor[] updates, double weight)
Description copied from interface: StochasticGradientOptimiser
Take a Tensor array of gradients and transform them according to the current weight and learning rates.
Can return the same Tensor array or a new one.
- Specified by:
step in interface StochasticGradientOptimiser
- Parameters:
updates - An array of gradients.
weight - The weight for the current gradients.
- Returns:
A Tensor array of gradients.
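The step contract (scale an array of gradients by the current learning rate and the weight of the current gradients, possibly reusing the input array) can be sketched without Tribuo's types. In this minimal sketch double[][] stands in for Tensor[], and the StepSketch.step helper with its explicit learningRate argument is hypothetical. It assumes plain SGD without momentum, and it updates the arrays in place and returns the same array, which is one of the two behaviours the contract allows.

```java
/** Hypothetical stand-in for step(Tensor[] updates, double weight), using double[][] instead of Tensor[]. */
final class StepSketch {
    /**
     * Scales every gradient entry by weight * learningRate in place
     * (assumes no momentum) and returns the same array object that was passed in.
     */
    static double[][] step(double[][] updates, double weight, double learningRate) {
        double scale = weight * learningRate;
        for (double[] gradient : updates) {
            for (int j = 0; j < gradient.length; j++) {
                gradient[j] *= scale;
            }
        }
        return updates;
    }
}
```

Because the input is mutated, a caller that needs the raw gradients afterwards would have to copy them first; an implementation that returns a new array avoids that at the cost of an allocation per step.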
-
learningRate
public abstract double learningRate()
Override to provide a function which calculates the learning rate. The only available information is the iteration count.
- Returns:
The current learning rate.
-
sgdType
protected abstract String sgdType()
Override to specify the kind of SGD.
- Returns:
A string representing the SGD type.
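Because learningRate() only sees the iteration count and sgdType() only names the variant, a concrete subclass is small. The standalone sketch below mirrors that template-method pattern; ScheduleSketch, ConstantSchedule and advance() are hypothetical names, and incrementing the counter before computing the rate (so the first step sees iteration 1) is an assumption rather than confirmed Tribuo behaviour.

```java
/** Hypothetical base class mirroring the learningRate()/sgdType() hooks described above. */
abstract class ScheduleSketch {
    protected final double initialLearningRate;
    protected int iteration = 0;

    ScheduleSketch(double initialLearningRate) {
        this.initialLearningRate = initialLearningRate;
    }

    /** Subclasses compute the rate from initialLearningRate and iteration only. */
    abstract double learningRate();

    /** Subclasses name their schedule. */
    protected abstract String sgdType();

    /** Assumed per-step hook: advances the iteration counter, then returns the current rate. */
    double advance() {
        iteration++;
        return learningRate();
    }

    /** Sketch of a reset: returns the counter to zero so a new set of parameters starts fresh. */
    void reset() {
        iteration = 0;
    }
}

/** Hypothetical constant-rate subclass. */
final class ConstantSchedule extends ScheduleSketch {
    ConstantSchedule(double initialLearningRate) {
        super(initialLearningRate);
    }

    @Override
    double learningRate() {
        return initialLearningRate; // ignores the iteration count
    }

    @Override
    protected String sgdType() {
        return "Constant";
    }
}
```

A decaying variant only has to change the body of learningRate(), for example to divide initialLearningRate by a function of iteration.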
-
toString
public String toString()
- Overrides:
toString in class java.lang.Object
-
reset
public void reset()
Description copied from interface: StochasticGradientOptimiser
Resets the optimiser so it's ready to optimise a new Parameters.
- Specified by:
reset in interface StochasticGradientOptimiser
-
getProvenance
public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
-
getSimpleSGD
public static SGD getSimpleSGD(double learningRate)
Generates an SGD optimiser with a constant learning rate set to learningRate.
- Parameters:
learningRate - The learning rate.
- Returns:
A constant learning rate SGD.
-
getSimpleSGD
public static SGD getSimpleSGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
- Parameters:
learningRate - The learning rate.
rho - The momentum drag constant.
momentumType - Momentum type.
- Returns:
A constant learning rate SGD with momentum.
-
getLinearDecaySGD
public static SGD getLinearDecaySGD(double learningRate)
Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate. The learning rate is initialLearningRate / iteration.
- Parameters:
learningRate - The initial learning rate.
- Returns:
A linear decay SGD.
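The documented formula can be checked directly. Despite the name, initialLearningRate / iteration shrinks in proportion to 1/t rather than along a straight line: the rate halves between iterations 1 and 2, then falls ever more slowly. The helper below is a hypothetical sketch of the formula only, assuming the iteration count starts at 1.

```java
/** Hypothetical helper that evaluates the documented "linear decay" formula. */
final class LinearDecaySketch {
    /** Returns initialLearningRate / iteration; iteration is assumed to start at 1. */
    static double rate(double initialLearningRate, int iteration) {
        if (iteration < 1) {
            throw new IllegalArgumentException("iteration must be >= 1, found " + iteration);
        }
        return initialLearningRate / iteration;
    }
}
```

With an initial rate of 1.0, iterations 1, 2, 4 and 10 give rates of 1.0, 0.5, 0.25 and 0.1.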
-
getLinearDecaySGD
public static SGD getLinearDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum. The learning rate is initialLearningRate / iteration.
- Parameters:
learningRate - The initial learning rate.
rho - The momentum drag constant.
momentumType - Momentum type.
- Returns:
A linear decay SGD with momentum.
-
getSqrtDecaySGD
public static SGD getSqrtDecaySGD(double learningRate)
Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate. The learning rate is initialLearningRate / sqrt(iteration).
- Parameters:
learningRate - The initial learning rate.
- Returns:
A sqrt decay SGD.
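The sqrt schedule decays much more slowly than initialLearningRate / iteration: after 100 iterations it has only fallen by a factor of 10, not 100. The helper below is a hypothetical sketch of the documented formula, assuming the iteration count starts at 1.

```java
/** Hypothetical helper that evaluates the documented sqrt decay formula. */
final class SqrtDecaySketch {
    /** Returns initialLearningRate / sqrt(iteration); iteration is assumed to start at 1. */
    static double rate(double initialLearningRate, int iteration) {
        if (iteration < 1) {
            throw new IllegalArgumentException("iteration must be >= 1, found " + iteration);
        }
        return initialLearningRate / Math.sqrt(iteration);
    }
}
```

With an initial rate of 1.0, iteration 4 gives 0.5 and iteration 100 gives 0.1, compared with 0.25 and 0.01 under the 1/iteration schedule.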
-
getSqrtDecaySGD
public static SGD getSqrtDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum. The learning rate is initialLearningRate / sqrt(iteration).
- Parameters:
learningRate - The initial learning rate.
rho - The momentum drag constant.
momentumType - Momentum type.
- Returns:
A sqrt decay SGD with momentum.
-