/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.util.ShrinkingMatrix;
import org.tribuo.math.optimisers.util.ShrinkingVector;

import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;

/**
 * An implementation of the RMSProp gradient optimiser.
 * <p>
 * Creates one copy of the parameters to store learning rates.
 * Follows the Keras implementation.
 * <p>
 * See:
 * <pre>
 * Tieleman, T. and Hinton, G.
 * Lecture 6.5 - RMSProp, COURSERA: Neural Networks for Machine Learning.
 * Technical report, 2012.
 * </pre>
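 * <p>
 * In pseudocode, {@link #step} applies the following update to each gradient element,
 * where weight is the example weight passed into step:
 * <pre>
 * learningRate = initialLearningRate / (1 + decay * iteration)
 * meanSquare   = rho * meanSquare + (1 - rho) * gradient * gradient
 * update       = weight * learningRate * gradient / (epsilon + sqrt(meanSquare))
 * </pre>
 * <p>
 * For example, constructing an optimiser with a learning rate of 0.01 and a rho of 0.9:
 * <pre>
 * StochasticGradientOptimiser optimiser = new RMSProp(0.01, 0.9);
 * </pre>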
 */
public class RMSProp implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(RMSProp.class.getName());

    @Config(mandatory = true,description="Learning rate to scale the gradients by.")
    private double initialLearningRate;

    @Config(description="Momentum parameter.")
    private double rho = 0.9;

    @Config(description="Epsilon for numerical stability.")
    private double epsilon = 1e-8;

    @Config(description="Decay factor for the learning rate.")
    private double decay = 0.0;

    private double invRho;

    private int iteration = 0;

    private Tensor[] gradsSquared;

    private DoubleUnaryOperator square;

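    /**
     * Constructs an RMSProp optimiser with the supplied hyperparameters.
     * @param initialLearningRate The learning rate used to scale the gradients.
     * @param rho The momentum parameter controlling the decay of the squared gradient accumulator.
     * @param epsilon Epsilon added to the denominator for numerical stability.
     * @param decay The decay applied to the learning rate as the iteration count increases.
     */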
    public RMSProp(double initialLearningRate, double rho, double epsilon, double decay) {
        this.initialLearningRate = initialLearningRate;
        this.rho = rho;
        this.epsilon = epsilon;
        this.decay = decay;
        this.iteration = 0;
        postConfig();
    }

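    /**
     * Constructs an RMSProp optimiser with the supplied learning rate and rho,
     * using the default epsilon of 1e-8 and no learning rate decay.
     * @param initialLearningRate The learning rate used to scale the gradients.
     * @param rho The momentum parameter controlling the decay of the squared gradient accumulator.
     */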
    public RMSProp(double initialLearningRate, double rho) {
        this(initialLearningRate,rho,1e-8,0.0);
    }

    /**
     * For olcut.
     */
    private RMSProp() { }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
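        // invRho caches (1 - rho): it weights the newly observed squared gradient in the accumulator
        // and is passed to the ShrinkingVector/ShrinkingMatrix wrappers created in initialise().
        // square maps a gradient element a to (1 - rho) * a * a.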
        this.invRho = 1.0 - rho;
        this.square = (double a) -> invRho*a*a;
    }

    @Override
    public void initialise(Parameters parameters) {
        gradsSquared = parameters.getEmptyCopy();
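        // Wrap each tensor in a shrinking view so the accumulator maintains an
        // exponentially decaying average of the squared gradients.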
        for (int i = 0; i < gradsSquared.length; i++) {
            if (gradsSquared[i] instanceof DenseVector) {
                gradsSquared[i] = new ShrinkingVector(((DenseVector) gradsSquared[i]), invRho, false);
            } else if (gradsSquared[i] instanceof DenseMatrix) {
                gradsSquared[i] = new ShrinkingMatrix(((DenseMatrix) gradsSquared[i]), invRho, false);
            } else {
                throw new IllegalStateException("Unknown Tensor subclass");
            }
        }
    }

    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        double learningRate = initialLearningRate / (1 + decay * iteration);
        //lifting lambdas out of the for loop until JDK-8183316 is fixed.
        DoubleUnaryOperator scale = (double a) -> weight * learningRate / (epsilon + Math.sqrt(a));
        for (int i = 0; i < updates.length; i++) {
            Tensor curGradsSquared = gradsSquared[i];
            Tensor curGrad = updates[i];
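            // Accumulate: gradsSquared[i] = rho * gradsSquared[i] + (1 - rho) * grad^2.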
            curGradsSquared.intersectAndAddInPlace(curGrad,square);
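            // Rescale the gradient in place: grad = weight * learningRate * grad / (epsilon + sqrt(gradsSquared[i])).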
            curGrad.hadamardProductInPlace(curGradsSquared,scale);
        }

        iteration++;
        return updates;
    }

    @Override
    public String toString() {
        return "RMSProp(initialLearningRate="+initialLearningRate+",rho="+rho+",epsilon="+epsilon+",decay="+decay+")";
    }

    @Override
    public void reset() {
        gradsSquared = null;
        iteration = 0;
    }

    @Override
    public RMSProp copy() {
        return new RMSProp(initialLearningRate,rho,epsilon,decay);
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
    }
}