Class SGD
java.lang.Object
org.tribuo.math.optimisers.SGD
- All Implemented Interfaces:
- com.oracle.labs.mlrg.olcut.config.Configurable
- com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
- StochasticGradientOptimiser
An implementation of single learning rate SGD, with optional momentum.
 
Has factory methods to generate constant learning rate, linear decay and sqrt decay variants.
See:
- Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent". Proceedings of COMPSTAT, 2010.
and for the momentum implementation:
- Shallue et al. "Measuring the Effects of Data Parallelism on Neural Network Training". 2018, arXiv 1811.03600.
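
As a usage sketch of the factory methods described above (the values 0.1 and 0.9 are arbitrary examples, and SGD.Momentum.STANDARD assumes the standard-momentum member of the nested SGD.Momentum enum):

    import org.tribuo.math.optimisers.SGD;

    // Constant, linear decay and sqrt decay variants; the momentum
    // overloads additionally take rho and a momentum type.
    SGD constant = SGD.getSimpleSGD(0.1);
    SGD linear = SGD.getLinearDecaySGD(0.1);
    SGD sqrtDecay = SGD.getSqrtDecaySGD(0.1, 0.9, SGD.Momentum.STANDARD);
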
Nested Class Summary
Nested Classes
- static enum SGD.Momentum: Momentum types.
Field Summary
Fields
- protected double initialLearningRate: The initial learning rate.
- protected int iteration: The iteration number, in steps.
- protected double rho: The scaling factor for the momentum.
- protected SGD.Momentum useMomentum: Should it use momentum.
Constructor Summary
Constructors
- protected SGD(): For olcut.
Method Summary
- static SGD getLinearDecaySGD(double learningRate): Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate.
- static SGD getLinearDecaySGD(double learningRate, double rho, SGD.Momentum momentumType): Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum.
- com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
- static SGD getSimpleSGD(double learningRate): Generates an SGD optimiser with a constant learning rate set to learningRate.
- static SGD getSimpleSGD(double learningRate, double rho, SGD.Momentum momentumType): Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
- static SGD getSqrtDecaySGD(double learningRate): Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate.
- static SGD getSqrtDecaySGD(double learningRate, double rho, SGD.Momentum momentumType): Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum.
- void initialise(Parameters parameters): Initialises the gradient optimiser.
- abstract double learningRate(): Override to provide a function which calculates the learning rate.
- void reset(): Resets the optimiser so it's ready to optimise a new Parameters.
- protected abstract String sgdType(): Override to specify the kind of SGD.
- Tensor[] step(Tensor[] updates, double weight): Take a Tensor array of gradients and transform them according to the current weight and learning rates.
- String toString()

Methods inherited from class java.lang.Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable:
postConfig
Methods inherited from interface org.tribuo.math.StochasticGradientOptimiser:
copy, finalise
Field Details

initialLearningRate
@Config(mandatory=true, description="Initial learning rate.")
protected double initialLearningRate
The initial learning rate.

useMomentum
protected SGD.Momentum useMomentum
Should it use momentum.

rho
@Config(description="Momentum scaling factor.")
protected double rho
The scaling factor for the momentum.

iteration
protected int iteration
The iteration number, in steps.
 
Constructor Details

SGD
protected SGD()
For olcut.
 
Method Details

initialise
public void initialise(Parameters parameters)
Description copied from interface: StochasticGradientOptimiser
Initialises the gradient optimiser. Configures any learning rate parameters.
Specified by:
- initialise in interface StochasticGradientOptimiser
Parameters:
- parameters - The parameters to optimise.
 
step
public Tensor[] step(Tensor[] updates, double weight)
Description copied from interface: StochasticGradientOptimiser
Take a Tensor array of gradients and transform them according to the current weight and learning rates. Can return the same Tensor array or a new one.
Specified by:
- step in interface StochasticGradientOptimiser
Parameters:
- updates - An array of gradients.
- weight - The weight for the current gradients.
Returns:
- A Tensor array of gradients.
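
A minimal sketch of the optimiser lifecycle, assuming the caller (typically a Tribuo trainer) supplies the Parameters instance and the gradient Tensor array; the learning rate 0.01 and gradient weight 1.0 are arbitrary example values:

    import org.tribuo.math.Parameters;
    import org.tribuo.math.la.Tensor;
    import org.tribuo.math.optimisers.SGD;

    static Tensor[] applyOneStep(Parameters params, Tensor[] gradients) {
        SGD optimiser = SGD.getSimpleSGD(0.01);
        optimiser.initialise(params);                     // configure learning rate state
        Tensor[] scaled = optimiser.step(gradients, 1.0); // transform gradients by the current rate
        optimiser.reset();                                // ready to optimise a new Parameters
        return scaled;
    }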
 
learningRate
public abstract double learningRate()
Override to provide a function which calculates the learning rate. The only available information is the iteration count.
Returns:
- The current learning rate.
 
sgdType
protected abstract String sgdType()
Override to specify the kind of SGD.
Returns:
- A string representing the SGD type.
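
As a sketch of how the two abstract members fit together, a hypothetical subclass (not part of Tribuo, shown for illustration only) could implement an exponential decay schedule on top of the protected fields documented above. It assumes SGD.Momentum.NONE is the no-momentum member of the nested enum, and uses the protected no-argument constructor:

    import org.tribuo.math.optimisers.SGD;

    // Hypothetical subclass, for illustration only.
    public class ExponentialDecaySGD extends SGD {
        public ExponentialDecaySGD(double learningRate) {
            this.initialLearningRate = learningRate;
            this.useMomentum = SGD.Momentum.NONE; // assumed enum member; disables momentum
        }

        @Override
        public double learningRate() {
            // Decay the initial rate by 0.1% per iteration step.
            return initialLearningRate * Math.pow(0.999, iteration);
        }

        @Override
        protected String sgdType() {
            return "ExponentialDecay";
        }
    }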
 
toString
public String toString()
Overrides:
- toString in class java.lang.Object

reset
public void reset()
Description copied from interface: StochasticGradientOptimiser
Resets the optimiser so it's ready to optimise a new Parameters.
Specified by:
- reset in interface StochasticGradientOptimiser
 
getProvenance
public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
Specified by:
- getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
 
getSimpleSGD
public static SGD getSimpleSGD(double learningRate)
Generates an SGD optimiser with a constant learning rate set to learningRate.
Parameters:
- learningRate - The learning rate.
Returns:
- A constant learning rate SGD.
 
getSimpleSGD
public static SGD getSimpleSGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
Parameters:
- learningRate - The learning rate.
- rho - The momentum drag constant.
- momentumType - Momentum type.
Returns:
- A constant learning rate SGD with momentum.
 
getLinearDecaySGD
public static SGD getLinearDecaySGD(double learningRate)
Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate.
The learning rate = initialLearningRate / iteration.
Parameters:
- learningRate - The learning rate.
Returns:
- A linear decay SGD.
 
getLinearDecaySGD
public static SGD getLinearDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum.
The learning rate = initialLearningRate / iteration.
Parameters:
- learningRate - The learning rate.
- rho - The momentum drag constant.
- momentumType - Momentum type.
Returns:
- A linear decay SGD with momentum.
 
getSqrtDecaySGD
public static SGD getSqrtDecaySGD(double learningRate)
Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate.
The learning rate = initialLearningRate / sqrt(iteration).
Parameters:
- learningRate - The learning rate.
Returns:
- A sqrt decay SGD.
 
getSqrtDecaySGD
public static SGD getSqrtDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum.
The learning rate = initialLearningRate / sqrt(iteration).
Parameters:
- learningRate - The learning rate.
- rho - The momentum drag constant.
- momentumType - Momentum type.
Returns:
- A sqrt decay SGD with momentum.
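
As a quick arithmetic illustration of the two decay schedules described above (a sketch; the initial value 0.1 is arbitrary):

    // Worked values for the documented formulas, starting from 0.1.
    double initial = 0.1;
    for (int iteration = 1; iteration <= 4; iteration++) {
        double linear = initial / iteration;          // 0.1000, 0.0500, 0.0333, 0.0250
        double sqrt = initial / Math.sqrt(iteration); // 0.1000, 0.0707, 0.0577, 0.0500
        System.out.printf("iteration %d: linear=%.4f sqrt=%.4f%n", iteration, linear, sqrt);
    }

The sqrt schedule decays more slowly, keeping the step size larger for longer.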
 
 