public class AdaGrad extends Object implements StochasticGradientOptimiser
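An implementation of the AdaGrad gradient optimiser. Creates one copy of the parameters to store the per-parameter learning-rate state (the running sum of squared gradients).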
See:
Duchi, J., Hazan, E., and Singer, Y. "Adaptive Subgradient Methods for Online Learning and Stochastic Optimization." Journal of Machine Learning Research, 12, 2011, pp. 2121-2159.
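For reference, the diagonal AdaGrad rule from Duchi et al. keeps a running sum of squared gradients for each parameter and divides each new gradient by its square root. Mapping η to initialLearningRate, ε to epsilon and w to the weight argument of step is an assumption based on the parameter names on this page, as is adding ε outside the square root (the form used in the paper); this class may place ε differently.

$$
G_{0,i} = 0, \qquad
G_{t,i} = G_{t-1,i} + g_{t,i}^{2}, \qquad
\tilde{g}_{t,i} = \frac{w\,\eta}{\epsilon + \sqrt{G_{t,i}}}\; g_{t,i}
$$

Here $g_{t,i}$ is the gradient of parameter $i$ at step $t$ and $\tilde{g}_{t,i}$ is the rescaled gradient that the caller then applies to the parameters. Because $G_{t,i}$ only grows, parameters that have seen large or frequent gradients take progressively smaller steps.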
Constructor | Description |
---|---|
AdaGrad(double initialLearningRate) | Sets epsilon to 1e-6. |
AdaGrad(double initialLearningRate, double epsilon) | |
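The two constructors differ only in whether epsilon is supplied. The plain-Java sketch below mirrors that pair; AdaGradConfig is a hypothetical class (not part of Tribuo), and implementing the one-argument form by delegating through this(...) is an assumption about how such a default is usually wired, not a statement about the real source.

```java
/**
 * Hypothetical stand-in for the two AdaGrad constructors (not Tribuo's class).
 * It only stores the configuration; no optimisation happens here.
 */
final class AdaGradConfig {
    /** Default documented for AdaGrad(double initialLearningRate). */
    static final double DEFAULT_EPSILON = 1e-6;

    private final double initialLearningRate;
    private final double epsilon;

    AdaGradConfig(double initialLearningRate, double epsilon) {
        this.initialLearningRate = initialLearningRate;
        this.epsilon = epsilon;
    }

    /** Delegates to the two-argument constructor with epsilon = 1e-6. */
    AdaGradConfig(double initialLearningRate) {
        this(initialLearningRate, DEFAULT_EPSILON);
    }

    double initialLearningRate() {
        return initialLearningRate;
    }

    double epsilon() {
        return epsilon;
    }

    public static void main(String[] args) {
        AdaGradConfig withDefault = new AdaGradConfig(0.1);
        AdaGradConfig withExplicit = new AdaGradConfig(0.1, 1e-8);
        System.out.println("default epsilon = " + withDefault.epsilon());
        System.out.println("explicit epsilon = " + withExplicit.epsilon());
    }
}
```

With the real class the equivalent calls would be new AdaGrad(0.1) and new AdaGrad(0.1, 1e-8). In AdaGrad-style updates epsilon typically only matters while the accumulated squared gradient is still close to zero, where it keeps the division finite.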
Modifier and Type | Method | Description |
---|---|---|
AdaGrad | copy() | Copies a gradient optimiser with its configuration. |
com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance | getProvenance() | |
void | initialise(Parameters parameters) | Initialises the gradient optimiser. |
void | reset() | Resets the optimiser so it is ready to optimise a new Parameters object. |
Tensor[] | step(Tensor[] updates, double weight) | Takes a Tensor array of gradients and transforms them according to the current weight and learning rates. |
String | toString() | |
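The initialise/step pair can be illustrated without Tribuo. The sketch below works on double[][] in place of Tensor[] and Parameters, and AdaGradSketch is a hypothetical name. It assumes a zero-initialised accumulator and the scaling weight * initialLearningRate / (epsilon + sqrt(G)) from the AdaGrad paper; Tribuo's own Tensor code may differ in those details.

```java
import java.util.Arrays;

/**
 * Hypothetical AdaGrad sketch on double[][] instead of Tribuo's Tensor[].
 * initialise allocates the squared-gradient accumulator; step rescales gradients.
 */
final class AdaGradSketch {
    private final double initialLearningRate;
    private final double epsilon;
    /** Running sum of squared gradients, same shape as the parameters. */
    private double[][] gradsSquared;

    AdaGradSketch(double initialLearningRate, double epsilon) {
        this.initialLearningRate = initialLearningRate;
        this.epsilon = epsilon;
    }

    /** Allocates a zero-filled accumulator matching each parameter block's length. */
    void initialise(double[][] parameters) {
        gradsSquared = new double[parameters.length][];
        for (int i = 0; i < parameters.length; i++) {
            gradsSquared[i] = new double[parameters[i].length];
        }
    }

    /**
     * Adds g^2 to the accumulator, then scales each gradient in place by
     * weight * eta / (epsilon + sqrt(G)). Returns the same array, which the
     * documented contract permits ("the same Tensor array or a new one").
     */
    double[][] step(double[][] updates, double weight) {
        if (gradsSquared == null) {
            throw new IllegalStateException("initialise must be called before step");
        }
        for (int i = 0; i < updates.length; i++) {
            double[] g = updates[i];
            double[] acc = gradsSquared[i];
            for (int j = 0; j < g.length; j++) {
                acc[j] += g[j] * g[j];
                g[j] *= weight * initialLearningRate / (epsilon + Math.sqrt(acc[j]));
            }
        }
        return updates;
    }

    public static void main(String[] args) {
        AdaGradSketch opt = new AdaGradSketch(0.5, 1e-6);
        opt.initialise(new double[][] {{0.0, 0.0}});
        // The same raw gradient twice: the second rescaled step is smaller
        // in magnitude because the accumulator has grown.
        System.out.println(Arrays.toString(opt.step(new double[][] {{2.0, -1.0}}, 1.0)[0]));
        System.out.println(Arrays.toString(opt.step(new double[][] {{2.0, -1.0}}, 1.0)[0]));
    }
}
```

Rescaling in place and returning the input array avoids allocating a fresh gradient array on every call; a caller that needs the raw gradients afterwards would have to copy them first.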
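Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface StochasticGradientOptimiser: finalise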

public AdaGrad(double initialLearningRate, double epsilon)

public AdaGrad(double initialLearningRate)
Parameters:
initialLearningRate - The learning rate.

public void initialise(Parameters parameters)
Description copied from interface: StochasticGradientOptimiser
Configures any learning rate parameters.
Specified by: initialise in interface StochasticGradientOptimiser
Parameters:
parameters - The parameters to optimise.

public Tensor[] step(Tensor[] updates, double weight)
Description copied from interface: StochasticGradientOptimiser
Takes a Tensor array of gradients and transforms them according to the current weight and learning rates. Can return the same Tensor array or a new one.
Specified by: step in interface StochasticGradientOptimiser
Parameters:
updates - An array of gradients.
weight - The weight for the current gradients.
Returns:
A Tensor array of gradients.

public void reset()
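Description copied from interface: StochasticGradientOptimiser
Resets the optimiser so it is ready to optimise a new Parameters object.
Specified by: reset in interface StochasticGradientOptimiser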
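For an AdaGrad-style optimiser, "ready to optimise a new Parameters object" plausibly means discarding the accumulated squared gradients so the history of one model does not shrink the steps of the next. The sketch below assumes exactly that, plus a fresh initialise call after reset; ResettableAdaGrad is hypothetical, uses a flat double[] for brevity, and is not Tribuo's implementation.

```java
/**
 * Hypothetical sketch of reset() for an AdaGrad-style optimiser:
 * accumulated squared gradients are dropped so the next model starts fresh.
 */
final class ResettableAdaGrad {
    private final double eta;
    private final double epsilon;
    private double[] gradsSquared; // null until initialise is called

    ResettableAdaGrad(double eta, double epsilon) {
        this.eta = eta;
        this.epsilon = epsilon;
    }

    /** Allocates a zero-filled accumulator for the given number of parameters. */
    void initialise(int numParameters) {
        gradsSquared = new double[numParameters];
    }

    /** Accumulates g^2, then rescales each gradient in place. */
    double[] step(double[] updates, double weight) {
        for (int j = 0; j < updates.length; j++) {
            gradsSquared[j] += updates[j] * updates[j];
            updates[j] *= weight * eta / (epsilon + Math.sqrt(gradsSquared[j]));
        }
        return updates;
    }

    /** Drops all accumulated state; initialise must be called again before step. */
    void reset() {
        gradsSquared = null;
    }

    public static void main(String[] args) {
        ResettableAdaGrad opt = new ResettableAdaGrad(0.1, 1e-6);
        opt.initialise(1);
        double first = opt.step(new double[] {3.0}, 1.0)[0];
        opt.step(new double[] {3.0}, 1.0); // history grows, so this step is smaller
        opt.reset();
        opt.initialise(1); // fresh accumulator for a new model
        double afterReset = opt.step(new double[] {3.0}, 1.0)[0];
        // Same input and same fresh state, so the two steps are identical.
        System.out.println(first == afterReset);
    }
}
```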
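
public AdaGrad copy()
Description copied from interface: StochasticGradientOptimiser
Copies a gradient optimiser with its configuration.
Specified by: copy in interface StochasticGradientOptimiser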
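This page only says that copy() reproduces the optimiser's configuration. The sketch below assumes the copy does not carry over the accumulated squared gradients, so two copies can optimise separate models without interfering. CopyableAdaGrad and its accessors are hypothetical names used for illustration only.

```java
/**
 * Hypothetical sketch of copy(): the copy keeps the configuration
 * (learning rate and epsilon) but not the accumulated squared gradients.
 */
final class CopyableAdaGrad {
    private final double eta;
    private final double epsilon;
    private double[] gradsSquared; // null until initialise is called

    CopyableAdaGrad(double eta, double epsilon) {
        this.eta = eta;
        this.epsilon = epsilon;
    }

    /** Configuration-only copy, in the style of a copy constructor. */
    CopyableAdaGrad copy() {
        return new CopyableAdaGrad(eta, epsilon);
    }

    void initialise(int numParameters) {
        gradsSquared = new double[numParameters];
    }

    /** Accumulates g^2, then rescales each gradient in place. */
    double[] step(double[] updates, double weight) {
        for (int j = 0; j < updates.length; j++) {
            gradsSquared[j] += updates[j] * updates[j];
            updates[j] *= weight * eta / (epsilon + Math.sqrt(gradsSquared[j]));
        }
        return updates;
    }

    double eta() { return eta; }
    double epsilon() { return epsilon; }
    boolean hasState() { return gradsSquared != null; }

    public static void main(String[] args) {
        CopyableAdaGrad original = new CopyableAdaGrad(0.1, 1e-6);
        original.initialise(1);
        original.step(new double[] {3.0}, 1.0);
        CopyableAdaGrad copy = original.copy();
        System.out.println("copy has state: " + copy.hasState());
        copy.initialise(1);
        // The copy's first step ignores the original's history.
        System.out.println(copy.step(new double[] {3.0}, 1.0)[0]);
    }
}
```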

public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.