/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

import java.util.function.DoubleUnaryOperator;

/**
 * An implementation of the Adam gradient optimiser.
 * <p>
 * Creates two copies of the parameters to store the first and second moment estimates of the gradient.
 * <p>
 * See:
 * <pre>
 * Kingma, D., and Ba, J.
 * "Adam: A Method for Stochastic Optimization"
 * arXiv preprint arXiv:1412.6980, 2014.
 * </pre>
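 * <p>
 * A rough usage sketch; the {@code params} (a {@code Parameters} instance), {@code gradients}
 * and {@code weight} values are assumed to be supplied by the calling trainer, which normally
 * drives the optimiser and applies the update via {@code Parameters.update}:
 * <pre>{@code
 * Adam adam = new Adam(0.001, 1e-6);              // betaOne = 0.9, betaTwo = 0.999
 * adam.initialise(params);                        // allocates the moment buffers
 * Tensor[] update = adam.step(gradients, weight); // one update per example or minibatch
 * params.update(update);                          // apply the update to the parameters
 * adam.reset();                                   // clear state before reusing the optimiser
 * }</pre>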
 */
public class Adam implements StochasticGradientOptimiser {

    @Config(description="Learning rate to scale the gradients by.")
    private double initialLearningRate = 0.001;

    @Config(description="The beta one parameter.")
    private double betaOne = 0.9;

    @Config(description="The beta two parameter.")
    private double betaTwo = 0.999;

    @Config(description="Epsilon for numerical stability.")
    private double epsilon = 1e-6;

    private int iterations = 0;
    private Tensor[] firstMoment;
    private Tensor[] secondMoment;

    /**
     * Constructs an Adam optimiser with the specified hyperparameters.
     * <p>
     * It's highly recommended not to modify betaOne, betaTwo or epsilon from their
     * default values; prefer one of the other constructors.
     * @param initialLearningRate The initial learning rate.
     * @param betaOne The value of beta-one.
     * @param betaTwo The value of beta-two.
     * @param epsilon The epsilon value.
     */
    public Adam(double initialLearningRate, double betaOne, double betaTwo, double epsilon) {
        this.initialLearningRate = initialLearningRate;
        this.betaOne = betaOne;
        this.betaTwo = betaTwo;
        this.epsilon = epsilon;
        this.iterations = 0;
    }

    /**
     * Constructs an Adam optimiser with the supplied learning rate and epsilon,
     * setting betaOne to 0.9 and betaTwo to 0.999.
     * @param initialLearningRate The initial learning rate.
     * @param epsilon The epsilon value.
     */
    public Adam(double initialLearningRate, double epsilon) {
        this(initialLearningRate,0.9,0.999,epsilon);
    }

    /**
     * Sets initialLearningRate to 0.001, betaOne to 0.9, betaTwo to 0.999, epsilon to 1e-6.
     * These are the parameters from the Adam paper, apart from epsilon, which the paper sets to 1e-8.
     */
    public Adam() {
        this(0.001,0.9,0.999,1e-6);
    }

    @Override
    public void initialise(Parameters parameters) {
        firstMoment = parameters.getEmptyCopy();
        secondMoment = parameters.getEmptyCopy();
        iterations = 0;
    }

    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        iterations++;

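        // Fold the bias corrections for both moment estimates into an effective step size:
        // alpha_t = initialLearningRate * sqrt(1 - betaTwo^t) / (1 - betaOne^t),
        // the reformulation suggested in Section 2 of the Adam paper.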
        double learningRate = initialLearningRate * Math.sqrt(1.0 - Math.pow(betaTwo,iterations)) / (1.0 - Math.pow(betaOne,iterations));
        //lifting lambdas out of the for loop until JDK-8183316 is fixed.
        DoubleUnaryOperator scale = (double a) -> a * learningRate;

        for (int i = 0; i < updates.length; i++) {
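            // Biased first moment update: m = betaOne * m + (1 - betaOne) * gradient.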
            firstMoment[i].scaleInPlace(betaOne);
            firstMoment[i].intersectAndAddInPlace(updates[i],(double a) -> a * (1.0 - betaOne));
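            // Biased second moment update: v = betaTwo * v + (1 - betaTwo) * gradient^2.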
            secondMoment[i].scaleInPlace(betaTwo);
            secondMoment[i].intersectAndAddInPlace(updates[i],(double a) -> a * a * (1.0 - betaTwo));
            updates[i].scaleInPlace(0.0); // zeroes the values while preserving the sparsity pattern
            updates[i].intersectAndAddInPlace(firstMoment[i],scale); // add in the first moment, scaled by the learning rate
            updates[i].hadamardProductInPlace(secondMoment[i],(double a) -> Math.sqrt(a) + epsilon); // scale by the second moment term
        }

        return updates;
    }

    @Override
    public String toString() {
        return "Adam(learningRate="+initialLearningRate+",betaOne="+betaOne+",betaTwo="+betaTwo+",epsilon="+epsilon+")";
    }

    @Override
    public void reset() {
        firstMoment = null;
        secondMoment = null;
        iterations = 0;
    }

    @Override
    public Adam copy() {
        return new Adam(initialLearningRate,betaOne,betaTwo,epsilon);
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
    }
}