/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;

/**
 * An implementation of single learning rate SGD, optionally with momentum.
 * <p>
 * Has factory methods to generate constant learning rate, linear decay and sqrt decay variants.
 * <p>
 * See:
 * <pre>
 * Bottou L.
 * "Large-Scale Machine Learning with Stochastic Gradient Descent"
 * Proceedings of COMPSTAT, 2010.
 * </pre>
 * and for the momentum implementation:
 * <pre>
 * Shallue et al.
 * "Measuring the Effects of Data Parallelism on Neural Network Training"
 * 2018, arXiv 1811.03600
 * </pre>
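 * <p>
 * For example, a sqrt decay optimiser with standard momentum could be constructed via the
 * factory methods below (the learning rate and rho values here are purely illustrative):
 * <pre>{@code
 * SGD optimiser = SGD.getSqrtDecaySGD(0.1, 0.9, SGD.Momentum.STANDARD);
 * }</pre>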
 */
public abstract class SGD implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(SGD.class.getName());

    /**
     * Momentum types.
     */
    public enum Momentum {
        /**
         * No momentum.
         */
        NONE,
        /**
         * Standard momentum.
         */
        STANDARD,
        /**
         * Nesterov momentum.
         */
        NESTEROV
    }

    @Config(mandatory = true,description="Initial learning rate.")
    protected double initialLearningRate;

    @Config(mandatory = true,description="Momentum type to use.")
    protected Momentum useMomentum;

    @Config(description="Momentum scaling factor.")
    protected double rho = 0.0;

    protected int iteration = 0;

    private Tensor[] momentum;

    /**
     * Constructs an SGD optimiser with the specified learning rate and no momentum.
     * @param learningRate The learning rate.
     */
    SGD(double learningRate) {
        this(learningRate,0.0,Momentum.NONE);
    }

    /**
     * Constructs an SGD optimiser with the specified learning rate and momentum settings.
     * @param learningRate The learning rate.
     * @param rho The momentum scaling factor.
     * @param useMomentum The momentum type to use.
     */
    SGD(double learningRate, double rho, Momentum useMomentum) {
        this.initialLearningRate = learningRate;
        this.useMomentum = useMomentum;
        this.rho = rho;
    }

    /**
     * For OLCUT.
     */
    protected SGD() { }

    @Override
    public void initialise(Parameters parameters) {
        if (useMomentum != Momentum.NONE) {
            momentum = parameters.getEmptyCopy();
        }
    }

    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        iteration++;
        double learningRate = learningRate();
        DoubleUnaryOperator learningRateFunc = (double a) -> a * learningRate * weight;
        DoubleUnaryOperator nesterovFunc = (double a) -> a * learningRate * weight * rho;

        /* Modelled after momentum as described in
         * "Measuring the Effects of Data Parallelism on Neural Network Training"
         * Shallue et al. 2018, arXiv 1811.03600
         */
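        /* In the dense case the effective updates computed below are (writing g for the
         * incoming gradient, v for the momentum buffer, and lr for the current learning rate):
         *   NONE:     update = lr * weight * g
         *   STANDARD: v = rho * v + g;  update = lr * weight * v
         *   NESTEROV: v = rho * v + g;  update = lr * weight * g + lr * weight * rho * v
         */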
        for (int i = 0; i < updates.length; i++) {
            switch (useMomentum) {
                case STANDARD:
                    momentum[i].scaleInPlace(rho);
                    momentum[i].intersectAndAddInPlace(updates[i]);
                    updates[i].scaleInPlace(0.0);
                    updates[i].intersectAndAddInPlace(momentum[i],learningRateFunc);
                    break;
                case NESTEROV:
                    momentum[i].scaleInPlace(rho);
                    momentum[i].intersectAndAddInPlace(updates[i]);
                    updates[i].scaleInPlace(weight * learningRate);
                    updates[i].intersectAndAddInPlace(momentum[i],nesterovFunc);
                    break;
                case NONE:
                default:
                    updates[i].scaleInPlace(weight * learningRate);
                    break;
            }
        }

        return updates;
    }

    /**
     * Override this method to calculate the learning rate.
     * The only available information is the iteration count.
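     * <p>
     * For instance, a hypothetical subclass implementing an exponential decay schedule
     * (the 0.99 decay factor is purely illustrative) might use:
     * <pre>{@code
     * public double learningRate() {
     *     return initialLearningRate * Math.pow(0.99, iteration);
     * }
     * }</pre>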
     * @return The current learning rate.
     */
    public abstract double learningRate();

    /**
     * Override to specify the kind of SGD.
     * @return A string representing the SGD type.
     */
    protected abstract String sgdType();

    @Override
    public String toString() {
        switch (useMomentum) {
            case STANDARD:
                return "SGD+Momentum(type=" + sgdType() + ",initialLearningRate=" + initialLearningRate + ",rho=" + rho + ")";
            case NESTEROV:
                return "SGD+NesterovMomentum(type=" + sgdType() + ",initialLearningRate=" + initialLearningRate + ",rho=" + rho + ")";
            default:
                return "SGD(type=" + sgdType() + ",initialLearningRate=" + initialLearningRate + ")";
        }
    }

    @Override
    public void reset() {
        momentum = null;
        iteration = 0;
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
    }

    /**
     * Generates an SGD optimiser with a constant learning rate set to learningRate.
     * @param learningRate The learning rate.
     * @return A constant learning rate SGD.
     */
    public static SGD getSimpleSGD(double learningRate) {
        return new SimpleSGD(learningRate);
    }

    /**
     * Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
     * @param learningRate The learning rate.
     * @param rho The momentum drag constant.
     * @param momentumType Momentum type.
     * @return A constant learning rate SGD with momentum.
     */
    public static SGD getSimpleSGD(double learningRate, double rho, Momentum momentumType) {
        return new SimpleSGD(learningRate, rho, momentumType);
    }

    /**
     * Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate.
     * <p>
     * The learning rate = initialLearningRate / iteration.
     * @param learningRate The learning rate.
     * @return A linear decay SGD.
     */
    public static SGD getLinearDecaySGD(double learningRate) {
        return new LinearDecaySGD(learningRate);
    }

    /**
     * Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum.
     * <p>
     * The learning rate = initialLearningRate / iteration.
     * @param learningRate The learning rate.
     * @param rho The momentum drag constant.
     * @param momentumType Momentum type.
     * @return A linear decay SGD with momentum.
     */
    public static SGD getLinearDecaySGD(double learningRate, double rho, Momentum momentumType) {
        return new LinearDecaySGD(learningRate, rho, momentumType);
    }

    /**
     * Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate.
     * <p>
     * The learning rate = initialLearningRate / sqrt(iteration).
     * @param learningRate The learning rate.
     * @return A sqrt decay SGD.
     */
    public static SGD getSqrtDecaySGD(double learningRate) {
        return new SqrtDecaySGD(learningRate);
    }

    /**
     * Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum.
     * <p>
     * The learning rate = initialLearningRate / sqrt(iteration).
     * @param learningRate The learning rate.
     * @param rho The momentum drag constant.
     * @param momentumType Momentum type.
     * @return A sqrt decay SGD with momentum.
     */
    public static SGD getSqrtDecaySGD(double learningRate, double rho, Momentum momentumType) {
        return new SqrtDecaySGD(learningRate, rho, momentumType);
    }
}

final class SimpleSGD extends SGD {
    public SimpleSGD(double learningRate) {
        super(learningRate);
    }

    public SimpleSGD(double learningRate, double rho, Momentum momentumType) {
        super(learningRate, rho, momentumType);
    }

    protected SimpleSGD() { }

    @Override
    public double learningRate() {
        return initialLearningRate;
    }

    @Override
    protected String sgdType() {
        return "Constant";
    }

    @Override
    public SimpleSGD copy() {
        return new SimpleSGD(initialLearningRate,rho,useMomentum);
    }
}

final class LinearDecaySGD extends SGD {
    public LinearDecaySGD(double learningRate) {
        super(learningRate);
    }

    public LinearDecaySGD(double learningRate, double rho, Momentum momentumType) {
        super(learningRate, rho, momentumType);
    }

    protected LinearDecaySGD() { }

    @Override
    public double learningRate() {
        return initialLearningRate / iteration;
    }

    @Override
    protected String sgdType() {
        return "LinearDecay";
    }

    @Override
    public LinearDecaySGD copy() {
        return new LinearDecaySGD(initialLearningRate,rho,useMomentum);
    }
}

final class SqrtDecaySGD extends SGD {
    public SqrtDecaySGD(double learningRate) {
        super(learningRate);
    }

    public SqrtDecaySGD(double learningRate, double rho, Momentum momentumType) {
        super(learningRate, rho, momentumType);
    }

    protected SqrtDecaySGD() { }

    @Override
    public double learningRate() {
        return initialLearningRate / Math.sqrt(iteration);
    }

    @Override
    protected String sgdType() {
        return "SqrtDecay";
    }

    @Override
    public SqrtDecaySGD copy() {
        return new SqrtDecaySGD(initialLearningRate,rho,useMomentum);
    }
}