001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.optimisers;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
022import org.tribuo.math.Parameters;
023import org.tribuo.math.StochasticGradientOptimiser;
024import org.tribuo.math.la.Tensor;
025
026/**
027 * An implementation of the AdaDelta gradient optimiser.
028 * <p>
029 * Creates two copies of the parameters to store learning rates.
030 * <p>
031 * See:
032 * <pre>
033 * Zeiler, MD.
034 * "ADADELTA: an Adaptive Learning Rate Method"
035 * arXiv preprint arXiv:1212.5701.
036 * </pre>
037 */
038public class AdaDelta implements StochasticGradientOptimiser {
039
040    @Config(description="Momentum value.")
041    private double rho = 0.95;
042
043    @Config(description="Epsilon for numerical stability.")
044    private double epsilon = 1e-6;
045
046    private Tensor[] gradsSquared;
047    private Tensor[] velocitySquared;
048
049    /**
050     * It's recommended to keep rho at 0.95.
051     * @param rho The rho value.
052     * @param epsilon The epsilon value.
053     */
054    public AdaDelta(double rho, double epsilon) {
055        this.rho = rho;
056        this.epsilon = epsilon;
057    }
058
059    /**
060     * Keeps rho at 0.95, passes through epsilon.
061     * @param epsilon The epsilon value.
062     */
063    public AdaDelta(double epsilon) {
064        this(0.95,epsilon);
065    }
066
067    /**
068     * Sets rho to 0.95 and epsilon to 1e-6.
069     */
070    public AdaDelta() {
071        this(0.95,1e-6);
072    }
073
074    @Override
075    public void initialise(Parameters parameters) {
076        gradsSquared = parameters.getEmptyCopy();
077        velocitySquared = parameters.getEmptyCopy();
078    }
079
080    @Override
081    public Tensor[] step(Tensor[] updates, double weight) {
082        for (int i = 0; i < updates.length; i++) {
083            gradsSquared[i].scaleInPlace(rho);
084            gradsSquared[i].intersectAndAddInPlace(updates[i],(double a) -> a * a * (1.0 - rho));
085            updates[i].hadamardProductInPlace(velocitySquared[i],(double a) -> Math.sqrt(a + epsilon));
086            updates[i].hadamardProductInPlace(gradsSquared[i],(double a) -> 1.0 / (Math.sqrt(a + epsilon)));
087            velocitySquared[i].scaleInPlace(rho);
088            velocitySquared[i].intersectAndAddInPlace(updates[i],(double a) -> a * a * (1.0 - rho));
089        }
090
091        return updates;
092    }
093
094    @Override
095    public String toString() {
096        return "AdaDelta(rho="+rho+",epsilon="+epsilon+")";
097    }
098
099    @Override
100    public void reset() {
101        gradsSquared = null;
102        velocitySquared = null;
103    }
104
105    @Override
106    public AdaDelta copy() {
107        return new AdaDelta(rho,epsilon);
108    }
109
110    @Override
111    public ConfiguredObjectProvenance getProvenance() {
112        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
113    }
114}