Class SGD

java.lang.Object
org.tribuo.math.optimisers.SGD
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, StochasticGradientOptimiser

public abstract class SGD extends Object implements StochasticGradientOptimiser
An implementation of SGD with a single learning rate and optional momentum.

Has factory methods to generate constant learning rate, linear decay and sqrt decay variants.

See:

 Bottou L.
 "Large-Scale Machine Learning with Stochastic Gradient Descent"
 Proceedings of COMPSTAT, 2010.
 
and for the momentum implementation:
 Shallue et al.,
 "Measuring the Effects of Data Parallelism on Neural Network Training"
 2018, arXiv:1811.03600
 
  • Nested Class Summary

    Nested Classes
    Modifier and Type
    Class
    Description
    static enum 
    SGD.Momentum
    Momentum types.
  • Field Summary

    Fields
    Modifier and Type
    Field
    Description
    protected double
    initialLearningRate
    The initial learning rate.
    protected int
    iteration
    The iteration number, in steps.
    protected double
    rho
    The scaling factor for the momentum.
    protected SGD.Momentum
    useMomentum
    The momentum type to use.
  • Constructor Summary

    Constructors
    Modifier
    Constructor
    Description
    protected
    SGD()
    No-argument constructor used by the OLCUT configuration system.
  • Method Summary

    Modifier and Type
    Method
    Description
    static SGD
    getLinearDecaySGD(double learningRate)
    Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate.
    static SGD
    getLinearDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
    Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum.
    com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance
    getProvenance()
     
    static SGD
    getSimpleSGD(double learningRate)
    Generates an SGD optimiser with a constant learning rate set to learningRate.
    static SGD
    getSimpleSGD(double learningRate, double rho, SGD.Momentum momentumType)
    Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
    static SGD
    getSqrtDecaySGD(double learningRate)
    Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate.
    static SGD
    getSqrtDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
    Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum.
    void
    initialise(Parameters parameters)
    Initialises the gradient optimiser.
    abstract double
    learningRate()
    Override to provide a function which calculates the learning rate.
    void
    reset()
    Resets the optimiser so it's ready to optimise a new Parameters.
    protected abstract String
    sgdType()
    Override to specify the kind of SGD.
    Tensor[]
    step(Tensor[] updates, double weight)
    Takes a Tensor array of gradients and transforms them according to the current weight and learning rates.
    String
    toString()
     

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

    Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable

    postConfig

    Methods inherited from interface org.tribuo.math.StochasticGradientOptimiser

    copy, finalise
  • Field Details

    • initialLearningRate

      @Config(mandatory=true, description="Initial learning rate.") protected double initialLearningRate
      The initial learning rate.
    • useMomentum

      @Config(mandatory=true, description="Momentum type to use.") protected SGD.Momentum useMomentum
      The momentum type to use.
    • rho

      @Config(description="Momentum scaling factor.") protected double rho
      The scaling factor for the momentum.
    • iteration

      protected int iteration
      The iteration number, in steps.
  • Constructor Details

    • SGD

      protected SGD()
      No-argument constructor used by the OLCUT configuration system.
  • Method Details

    • initialise

      public void initialise(Parameters parameters)
      Description copied from interface: StochasticGradientOptimiser
      Initialises the gradient optimiser.

      Configures any learning rate parameters.

      Specified by:
      initialise in interface StochasticGradientOptimiser
      Parameters:
      parameters - The parameters to optimise.
    • step

      public Tensor[] step(Tensor[] updates, double weight)
      Description copied from interface: StochasticGradientOptimiser
      Takes a Tensor array of gradients and transforms them according to the current weight and learning rates.

      Can return the same Tensor array or a new one.

      Specified by:
      step in interface StochasticGradientOptimiser
      Parameters:
      updates - An array of gradients.
      weight - The weight for the current gradients.
      Returns:
      A Tensor array of gradients.
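
      To make the contract of step concrete, here is a minimal standalone sketch that scales gradients by the current learning rate and the example weight. It uses plain double arrays instead of Tensor, the class name StepSketch and the explicit learningRate argument are hypothetical, and it ignores momentum; it is not Tribuo's implementation. It takes the "same array" option the description permits and modifies the gradients in place.

```java
final class StepSketch {
    /**
     * Scales every gradient value in place by learningRate * weight
     * and returns the same outer array that was passed in.
     * (Assumption: a plain scaling rule, with no momentum term.)
     */
    static double[][] step(double[][] updates, double weight, double learningRate) {
        double scale = learningRate * weight;
        for (double[] tensor : updates) {
            for (int i = 0; i < tensor.length; i++) {
                tensor[i] *= scale;
            }
        }
        return updates;
    }

    public static void main(String[] args) {
        double[][] gradients = { {1.0, 2.0}, {4.0} };
        double[][] scaled = step(gradients, 0.5, 0.1);
        System.out.println(java.util.Arrays.deepToString(scaled));
    }
}
```

      Returning the input array avoids an allocation per step, at the cost of overwriting the caller's gradients.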
    • learningRate

      public abstract double learningRate()
      Override to provide a function which calculates the learning rate. The only available information is the iteration count.
      Returns:
      The current learning rate.
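
      As an illustration of a learning rate that depends only on the iteration count, the following standalone sketch mirrors the three schedules named by the factory methods on this page (constant, linear decay, sqrt decay). The class ScheduleSketch, its advance helper, and the choice to start counting at 1 are assumptions made for the example; it does not extend or reproduce the real SGD class.

```java
abstract class ScheduleSketch {
    protected final double initialLearningRate;
    // Assumption: counting starts at 1 so the decaying schedules never divide by zero.
    protected int iteration = 1;

    ScheduleSketch(double initialLearningRate) {
        this.initialLearningRate = initialLearningRate;
    }

    /** The schedule may only look at the iteration count. */
    abstract double learningRate();

    /** Hypothetical helper that moves to the next step. */
    void advance() {
        iteration++;
    }

    static ScheduleSketch constant(double lr) {
        return new ScheduleSketch(lr) {
            @Override
            double learningRate() { return initialLearningRate; }
        };
    }

    static ScheduleSketch linearDecay(double lr) {
        return new ScheduleSketch(lr) {
            @Override
            double learningRate() { return initialLearningRate / iteration; }
        };
    }

    static ScheduleSketch sqrtDecay(double lr) {
        return new ScheduleSketch(lr) {
            @Override
            double learningRate() { return initialLearningRate / Math.sqrt(iteration); }
        };
    }

    public static void main(String[] args) {
        ScheduleSketch linear = linearDecay(1.0);
        ScheduleSketch sqrt = sqrtDecay(1.0);
        for (int i = 0; i < 4; i++) {
            System.out.printf("iteration=%d linear=%.4f sqrt=%.4f%n",
                    linear.iteration, linear.learningRate(), sqrt.learningRate());
            linear.advance();
            sqrt.advance();
        }
    }
}
```

      Linear decay shrinks the rate much faster than sqrt decay, so it usually needs a larger initial learning rate to make comparable progress over many steps.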
    • sgdType

      protected abstract String sgdType()
      Override to specify the kind of SGD.
      Returns:
      A string representing the SGD type.
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • reset

      public void reset()
      Description copied from interface: StochasticGradientOptimiser
      Resets the optimiser so it's ready to optimise a new Parameters.
      Specified by:
      reset in interface StochasticGradientOptimiser
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
    • getSimpleSGD

      public static SGD getSimpleSGD(double learningRate)
      Generates an SGD optimiser with a constant learning rate set to learningRate.
      Parameters:
      learningRate - The learning rate.
      Returns:
      A constant learning rate SGD.
    • getSimpleSGD

      public static SGD getSimpleSGD(double learningRate, double rho, SGD.Momentum momentumType)
      Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
      Parameters:
      learningRate - The learning rate.
      rho - The momentum drag constant.
      momentumType - Momentum type.
      Returns:
      A constant learning rate SGD with momentum.
    • getLinearDecaySGD

      public static SGD getLinearDecaySGD(double learningRate)
      Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate.

      The learning rate = initialLearningRate / iteration.

      Parameters:
      learningRate - The initial learning rate.
      Returns:
      A linear decay SGD.
    • getLinearDecaySGD

      public static SGD getLinearDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
      Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum.

      The learning rate = initialLearningRate / iteration.

      Parameters:
      learningRate - The initial learning rate.
      rho - The momentum drag constant.
      momentumType - Momentum type.
      Returns:
      A linear decay SGD with momentum.
    • getSqrtDecaySGD

      public static SGD getSqrtDecaySGD(double learningRate)
      Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate.

      The learning rate = initialLearningRate / sqrt(iteration).

      Parameters:
      learningRate - The initial learning rate.
      Returns:
      A sqrt decay SGD.
    • getSqrtDecaySGD

      public static SGD getSqrtDecaySGD(double learningRate, double rho, SGD.Momentum momentumType)
      Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum.

      The learning rate = initialLearningRate / sqrt(iteration).

      Parameters:
      learningRate - The initial learning rate.
      rho - The momentum drag constant.
      momentumType - Momentum type.
      Returns:
      A sqrt decay SGD with momentum.
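
      The factory overloads that accept rho and momentumType add a momentum term to the update. The sketch below shows one common formulation of standard and Nesterov momentum, in the style described by Shallue et al., on a plain double array. The class MomentumSketch and its enum MomentumKind are hypothetical stand-ins (the constants of SGD.Momentum are not listed on this page), and whether Tribuo's update matches this term for term is an assumption.

```java
final class MomentumSketch {
    /** Hypothetical stand-in for SGD.Momentum. */
    enum MomentumKind { NONE, STANDARD, NESTEROV }

    private final double learningRate;
    private final double rho;
    private final MomentumKind kind;
    private double[] velocity; // allocated on the first step

    MomentumSketch(double learningRate, double rho, MomentumKind kind) {
        this.learningRate = learningRate;
        this.rho = rho;
        this.kind = kind;
    }

    /** Returns the update to subtract from the parameters for this gradient. */
    double[] step(double[] gradient) {
        if (velocity == null) {
            velocity = new double[gradient.length];
        }
        double[] update = new double[gradient.length];
        for (int i = 0; i < gradient.length; i++) {
            switch (kind) {
                case NONE:
                    update[i] = learningRate * gradient[i];
                    break;
                case STANDARD:
                    // v <- rho * v + g; update = lr * v
                    velocity[i] = rho * velocity[i] + gradient[i];
                    update[i] = learningRate * velocity[i];
                    break;
                case NESTEROV:
                    // v <- rho * v + g; update = lr * (g + rho * v)
                    velocity[i] = rho * velocity[i] + gradient[i];
                    update[i] = learningRate * (gradient[i] + rho * velocity[i]);
                    break;
            }
        }
        return update;
    }

    public static void main(String[] args) {
        MomentumSketch opt = new MomentumSketch(0.1, 0.5, MomentumKind.STANDARD);
        double[] gradient = {1.0};
        System.out.println(opt.step(gradient)[0]);
        System.out.println(opt.step(gradient)[0]);
    }
}
```

      With a constant gradient the standard-momentum update grows from one step to the next as the velocity accumulates, which is the effect rho controls.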