public abstract class SGD extends Object implements StochasticGradientOptimiser

Has factory methods to generate constant learning rate, linear decay, and sqrt decay variants.

See:
Bottou L., "Large-Scale Machine Learning with Stochastic Gradient Descent", Proceedings of COMPSTAT, 2010.

For the momentum implementation, see:
Shallue et al., "Measuring the Effects of Data Parallelism on Neural Network Training", 2018, arXiv 1811.03600.
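The three factory families differ only in how the learning rate evolves with the iteration count. As a rough sketch, the schedules could look like the following; the exact division-by-iteration formulas here are illustrative assumptions, not Tribuo's verbatim code, so consult the `learningRate()` implementations in the Tribuo source for the real schedules.

```java
// Sketch of the three learning rate schedules named by the factory methods.
// The exact formulas are assumptions for illustration only.
class DecaySchedules {
    // Constant: getSimpleSGD always returns the initial rate.
    static double simple(double initial, int iteration) {
        return initial;
    }

    // Linear decay: the rate shrinks proportionally to the iteration count.
    static double linearDecay(double initial, int iteration) {
        return initial / iteration;
    }

    // Sqrt decay: the rate shrinks with the square root of the iteration count.
    static double sqrtDecay(double initial, int iteration) {
        return initial / Math.sqrt(iteration);
    }
}
```

Sqrt decay sits between the other two: it decays the rate more slowly than linear decay, which is often useful when training runs for many iterations.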
Nested classes:

Modifier and Type | Class and Description
---|---
static class | SGD.Momentum: Momentum types.
Fields:

Modifier and Type | Field
---|---
protected double | initialLearningRate
protected int | iteration
protected double | rho
protected SGD.Momentum | useMomentum
Constructors:

Modifier | Constructor and Description
---|---
protected | SGD(): For olcut.
Methods:

Modifier and Type | Method and Description
---|---
static SGD | getLinearDecaySGD(double learningRate): Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate.
static SGD | getLinearDecaySGD(double learningRate, double rho, SGD.Momentum momentumType): Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum.
com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance | getProvenance()
static SGD | getSimpleSGD(double learningRate): Generates an SGD optimiser with a constant learning rate set to learningRate.
static SGD | getSimpleSGD(double learningRate, double rho, SGD.Momentum momentumType): Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
static SGD | getSqrtDecaySGD(double learningRate): Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate.
static SGD | getSqrtDecaySGD(double learningRate, double rho, SGD.Momentum momentumType): Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum.
void | initialise(Parameters parameters): Initialises the gradient optimiser.
abstract double | learningRate(): Override to provide a function which calculates the learning rate.
void | reset(): Resets the optimiser so it's ready to optimise a new Parameters.
protected abstract String | sgdType(): Override to specify the kind of SGD.
Tensor[] | step(Tensor[] updates, double weight): Takes a Tensor array of gradients and transforms them according to the current weight and learning rates.
String | toString()
Methods inherited from class java.lang.Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface StochasticGradientOptimiser:
copy, finalise
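The method summary above implies a simple lifecycle: initialise the optimiser against a Parameters object, call step once per gradient batch, and reset before reusing it on new Parameters. A self-contained sketch of that lifecycle, using a plain double[] in place of Tribuo's Tensor, could look like this; the class and the scale-by-weight-and-rate rule are illustrative assumptions, not the Tribuo implementation.

```java
// Minimal stand-in for the initialise/step/reset lifecycle described above.
// double[] replaces Tribuo's Tensor[]; the scaling rule mirrors step()'s
// contract of transforming gradients by the current weight and learning
// rate, but is an illustrative assumption.
class SgdLifecycleSketch {
    private final double learningRate;
    private int iteration;

    SgdLifecycleSketch(double learningRate) {
        this.learningRate = learningRate;
    }

    // Corresponds to initialise(Parameters): set up per-run state.
    void initialise() {
        iteration = 0;
    }

    // Corresponds to step(Tensor[] updates, double weight):
    // scale each gradient by the weight and the current learning rate.
    double[] step(double[] updates, double weight) {
        iteration++;
        double[] scaled = new double[updates.length];
        for (int i = 0; i < updates.length; i++) {
            scaled[i] = updates[i] * weight * learningRate;
        }
        return scaled;
    }

    // Corresponds to reset(): ready the optimiser for new Parameters.
    void reset() {
        iteration = 0;
    }
}
```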
Field details:

@Config(mandatory=true, description="Initial learning rate.")
protected double initialLearningRate

@Config(mandatory=true, description="Momentum type to use.")
protected SGD.Momentum useMomentum

@Config(description="Momentum scaling factor.")
protected double rho

protected int iteration
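The rho field is the momentum scaling factor consumed by whichever SGD.Momentum variant is configured in useMomentum. The standard heavy-ball form accumulates a velocity as rho times the previous velocity plus the current gradient; the sketch below shows that form for a single scalar parameter, but it is an illustrative assumption, not a reproduction of the update in SGD.Momentum.

```java
// Sketch of how a momentum buffer driven by rho can modify a gradient,
// following the standard heavy-ball form: velocity = rho * velocity + gradient.
// The exact update used by SGD.Momentum lives in the Tribuo source.
class MomentumSketch {
    private final double rho;   // momentum scaling factor, as in the rho field
    private double velocity;    // one scalar parameter for simplicity

    MomentumSketch(double rho) {
        this.rho = rho;
    }

    // Fold the new gradient into the velocity and return the update to apply.
    double apply(double gradient) {
        velocity = rho * velocity + gradient;
        return velocity;
    }
}
```

With rho = 0, every update is just the raw gradient; larger rho values let past gradients persist and smooth the update direction.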
Method details:

public void initialise(Parameters parameters)
Description copied from interface: StochasticGradientOptimiser
Initialises the gradient optimiser. Configures any learning rate parameters.
Specified by: initialise in interface StochasticGradientOptimiser
Parameters: parameters - The parameters to optimise.

public Tensor[] step(Tensor[] updates, double weight)
Description copied from interface: StochasticGradientOptimiser
Takes a Tensor array of gradients and transforms them according to the current weight and learning rates. Can return the same Tensor array or a new one.
Specified by: step in interface StochasticGradientOptimiser
Parameters: updates - An array of gradients. weight - The weight for the current gradients.
Returns: A Tensor array of gradients.

public abstract double learningRate()
Override to provide a function which calculates the learning rate.

protected abstract String sgdType()
Override to specify the kind of SGD.

public void reset()
Description copied from interface: StochasticGradientOptimiser
Resets the optimiser so it's ready to optimise a new Parameters.
Specified by: reset in interface StochasticGradientOptimiser

public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
public static SGD getSimpleSGD(double learningRate)
Generates an SGD optimiser with a constant learning rate set to learningRate.
Parameters: learningRate - The learning rate.

public static SGD getSimpleSGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
Parameters: learningRate - The learning rate. rho - The momentum drag constant. momentumType - Momentum type.

public static SGD getLinearDecaySGD(double learningRate)
Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate.
Parameters: learningRate - The learning rate.

public static SGD getLinearDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum.
Parameters: learningRate - The learning rate. rho - The momentum drag constant. momentumType - Momentum type.

public static SGD getSqrtDecaySGD(double learningRate)
Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate.
Parameters: learningRate - The learning rate.

public static SGD getSqrtDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum.
Parameters: learningRate - The learning rate. rho - The momentum drag constant. momentumType - Momentum type.

Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.