/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;

/**
 * An implementation of the AdaGrad gradient optimiser.
 * <p>
 * Creates one copy of the parameters to store the accumulated squared gradients, which are used to scale the learning rate for each parameter.
 * <p>
 * See:
 * <pre>
 * Duchi, J., Hazan, E., and Singer, Y.
 * "Adaptive Subgradient Methods for Online Learning and Stochastic Optimization"
 * Journal of Machine Learning Research, 12, 2011, pp 2121-2159.
 * </pre>
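 * <p>
 * A minimal construction sketch (the learning rate and epsilon values below are
 * illustrative, not tuned recommendations); the optimiser is then typically passed
 * to an SGD-based trainer:
 * <pre>
 * StochasticGradientOptimiser optimiser = new AdaGrad(0.1, 1e-6);
 * </pre>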
 */
public class AdaGrad implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(AdaGrad.class.getName());

    @Config(mandatory = true,description="Initial learning rate used to scale the gradients.")
    private double initialLearningRate;

    @Config(description="Epsilon for numerical stability around zero.")
    private double epsilon = 1e-6;

    @Config(description="Initial value for the gradient accumulator.")
    private double initialValue = 0.0;

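    // Per-parameter accumulation of squared gradients; created in initialise() and cleared by reset().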
    private Tensor[] gradsSquared;

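    /**
     * Constructs an AdaGrad optimiser with the specified initial learning rate and epsilon.
     * @param initialLearningRate The initial learning rate used to scale the gradients.
     * @param epsilon The epsilon used for numerical stability around zero.
     */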
    public AdaGrad(double initialLearningRate, double epsilon) {
        this.initialLearningRate = initialLearningRate;
        this.epsilon = epsilon;
    }

    /**
     * Constructs an AdaGrad optimiser with the specified initial learning rate, setting epsilon to 1e-6.
     * @param initialLearningRate The learning rate.
     */
    public AdaGrad(double initialLearningRate) {
        this(initialLearningRate,1e-6);
    }

    /**
     * For olcut.
     */
    private AdaGrad() { }

    @Override
    public void initialise(Parameters parameters) {
        this.gradsSquared = parameters.getEmptyCopy();
        if (initialValue != 0.0) {
            for (Tensor t : gradsSquared) {
                t.scalarAddInPlace(initialValue);
            }
        }
    }

    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        //lifting lambdas out of the for loop until JDK-8183316 is fixed.
        DoubleUnaryOperator square = (double a) -> weight*weight*a*a;
        DoubleUnaryOperator scale = (double a) -> weight * initialLearningRate / (epsilon + Math.sqrt(a));
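        // For each parameter Tensor: add this update's (weighted) squared gradient to the accumulator,
        // then rescale the gradient in place by initialLearningRate / (epsilon + sqrt(accumulated squared gradient)).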
        for (int i = 0; i < updates.length; i++) {
            Tensor curGradsSquared = gradsSquared[i];
            Tensor curGrad = updates[i];
            curGradsSquared.intersectAndAddInPlace(curGrad,square);
            curGrad.hadamardProductInPlace(curGradsSquared,scale);
        }

        return updates;
    }

    @Override
    public String toString() {
        return "AdaGrad(initialLearningRate="+initialLearningRate+",epsilon="+epsilon+",initialValue="+initialValue+")";
    }

    @Override
    public void reset() {
        gradsSquared = null;
    }

    @Override
    public AdaGrad copy() {
        return new AdaGrad(initialLearningRate,epsilon);
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
    }
}