/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.MatrixIterator;
import org.tribuo.math.la.MatrixTuple;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;

import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;

/**
 * An implementation of the AdaGrad gradient optimiser with regularized dual averaging.
 * <p>
 * This gradient optimiser rewrites all the {@link Tensor}s in the {@link Parameters}
 * with {@link AdaGradRDATensor}s. These tensors store the accumulated gradient internally,
 * separately from the value produced by a call to get(), so the regularisation can be applied
 * correctly when the parameters are read.
 * When {@link AdaGradRDA#finalise()} is called it rewrites the {@link Parameters} with standard dense {@link Tensor}s.
 * Follows the implementation in Factorie.
 * <p>
 * See:
 * <pre>
 * Duchi, J., Hazan, E., and Singer, Y.
 * "Adaptive Subgradient Methods for Online Learning and Stochastic Optimization"
 * Journal of Machine Learning Research, 12:2121-2159, 2011.
 * </pre>
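 * <p>
 * For example, a minimal construction (the hyperparameter values here are illustrative only):
 * <pre>{@code
 *     // learning rate 0.1, epsilon 1e-6, l1 penalty 0.01, no l2 penalty, penalties scaled by 1000 examples
 *     StochasticGradientOptimiser optimiser = new AdaGradRDA(0.1, 1e-6, 0.01, 0.0, 1000);
 * }</pre>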
 */
public class AdaGradRDA implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(AdaGradRDA.class.getName());

    @Config(mandatory = true,description="Initial learning rate used to scale the gradients.")
    private double initialLearningRate;

    @Config(description="Epsilon for numerical stability around zero.")
    private double epsilon = 1e-6;

    @Config(description="l1 regularization penalty.")
    private double l1 = 0;

    @Config(description="l2 regularization penalty.")
    private double l2 = 0;

    @Config(description="Number of examples to scale the l1 and l2 penalties by.")
    private int numExamples = 1;

    private Parameters parameters;

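    /**
     * Constructs an AdaGradRDA optimiser with the specified hyperparameters.
     * @param initialLearningRate The initial learning rate used to scale the gradients.
     * @param epsilon The epsilon for numerical stability around zero.
     * @param l1 The l1 regularisation penalty.
     * @param l2 The l2 regularisation penalty.
     * @param numExamples The number of examples to scale the l1 and l2 penalties by.
     */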
    public AdaGradRDA(double initialLearningRate, double epsilon, double l1, double l2, int numExamples) {
        this.initialLearningRate = initialLearningRate;
        this.epsilon = epsilon;
        this.l1 = l1;
        this.l2 = l2;
        this.numExamples = numExamples;
    }

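    /**
     * Constructs an AdaGradRDA optimiser with the specified learning rate and epsilon, without regularisation.
     * @param initialLearningRate The initial learning rate used to scale the gradients.
     * @param epsilon The epsilon for numerical stability around zero.
     */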
    public AdaGradRDA(double initialLearningRate, double epsilon) {
        this(initialLearningRate,epsilon,0,0,1);
    }

    /**
     * For olcut.
     */
    private AdaGradRDA() { }

    @Override
    public void initialise(Parameters parameters) {
        this.parameters = parameters;
        Tensor[] curParams = parameters.get();
        Tensor[] newParams = new Tensor[curParams.length];
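        // Wrap each parameter Tensor in an RDA-aware subclass which accumulates the gradients
        // internally and applies the regularised, adaptively scaled value when it is read.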
        for (int i = 0; i < newParams.length; i++) {
            if (curParams[i] instanceof DenseVector) {
                newParams[i] = new AdaGradRDAVector(((DenseVector) curParams[i]), initialLearningRate, epsilon, l1 / numExamples, l2 / numExamples);
            } else if (curParams[i] instanceof DenseMatrix) {
                newParams[i] = new AdaGradRDAMatrix(((DenseMatrix) curParams[i]), initialLearningRate, epsilon, l1 / numExamples, l2 / numExamples);
            } else {
                throw new IllegalStateException("Unknown Tensor subclass");
            }
        }
        parameters.set(newParams);
    }

    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
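        // The adaptive learning rates and the l1/l2 regularisation are applied lazily inside the
        // wrapped tensors, so the step itself only scales the gradients by the example weight.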
        for (Tensor update : updates) {
            update.scaleInPlace(weight);
        }

        return updates;
    }

    @Override
    public void finalise() {
        Tensor[] curParams = parameters.get();
        Tensor[] newParams = new Tensor[curParams.length];
        for (int i = 0; i < newParams.length; i++) {
            if (curParams[i] instanceof AdaGradRDATensor) {
                newParams[i] = ((AdaGradRDATensor) curParams[i]).convertToDense();
            } else {
                throw new IllegalStateException("Finalising a Parameters which wasn't initialised with AdaGradRDA");
            }
        }
        parameters.set(newParams);
    }

    @Override
    public String toString() {
        return "AdaGradRDA(initialLearningRate="+initialLearningRate+",epsilon="+epsilon+",l1="+l1+",l2="+l2+")";
    }

    @Override
    public void reset() { }

    @Override
    public AdaGradRDA copy() {
        return new AdaGradRDA(initialLearningRate,epsilon,l1,l2,numExamples);
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
    }

    /**
     * An interface which tags a {@link Tensor} with a convertToDense method and provides
     * the shared soft-thresholding helper used by the RDA update.
     */
    private interface AdaGradRDATensor {
        Tensor convertToDense();

        /**
         * Soft-thresholds the input, shrinking it towards zero by the threshold,
         * e.g. truncate(3.0, 1.0) == 2.0 and truncate(-0.5, 1.0) == 0.0.
         */
        static double truncate(double input, double threshold) {
            if (input > threshold) {
                return input - threshold;
            } else if (input < -threshold) {
                return input + threshold;
            } else {
                return 0.0;
            }
        }

    }

    /**
     * A subclass of {@link DenseVector} which accumulates the gradient updates in place and uses
     * {@link AdaGradRDATensor#truncate(double, double)} to produce the regularised values when
     * they are read.
     * <p>
     * Be careful when modifying this or {@link DenseVector}.
     */
    private static class AdaGradRDAVector extends DenseVector implements AdaGradRDATensor {
        private final double learningRate;
        private final double epsilon;
        private final double l1;
        private final double l2;
        private final double[] gradSquares;
        private int iteration;

        public AdaGradRDAVector(DenseVector v, double learningRate, double epsilon, double l1, double l2) {
            super(v);
            this.learningRate = learningRate;
            this.epsilon = epsilon;
            this.l1 = l1;
            this.l2 = l2;
            this.gradSquares = new double[v.size()];
            this.iteration = 0;
        }

        private AdaGradRDAVector(double[] values, double learningRate, double epsilon, double l1, double l2, double[] gradSquares, int iteration) {
            super(values);
            this.learningRate = learningRate;
            this.epsilon = epsilon;
            this.l1 = l1;
            this.l2 = l2;
            this.gradSquares = gradSquares;
            this.iteration = iteration;
        }

        @Override
        public DenseVector convertToDense() {
            return DenseVector.createDenseVector(toArray());
        }

        @Override
        public AdaGradRDAVector copy() {
            return new AdaGradRDAVector(Arrays.copyOf(elements,elements.length),learningRate,epsilon,l1,l2,Arrays.copyOf(gradSquares,gradSquares.length),iteration);
        }

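        // The value-reading overrides below go through get(index) so that callers observe the
        // regularised values rather than the raw accumulated gradients.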
        @Override
        public double[] toArray() {
            double[] newValues = new double[elements.length];
            for (int i = 0; i < newValues.length; i++) {
                newValues[i] = get(i);
            }
            return newValues;
        }

        @Override
        public double get(int index) {
            if (gradSquares[index] == 0.0) {
                // No gradients seen yet, return the stored value directly.
                return elements[index];
            } else {
                // RDA update: l1-truncate the accumulated value, then scale by the adaptive rate 1/h,
                // where h = (sqrt(sum of squared gradients) + epsilon)/learningRate + iteration*l2.
                double h = ((Math.sqrt(gradSquares[index]) + epsilon) / learningRate) + iteration * l2;
                double rate = 1.0/h;
                return rate * AdaGradRDATensor.truncate(elements[index], iteration*l1);
            }
        }

        @Override
        public double sum() {
            double sum = 0.0;
            for (int i = 0; i < elements.length; i++) {
                sum += get(i);
            }
            return sum;
        }

        @Override
        public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
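            // Each gradient application counts as one RDA iteration; accumulate the update and
            // its square for the adaptive learning rate.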
            iteration++;
            SGDVector otherVec = (SGDVector) other;
            for (VectorTuple tuple : otherVec) {
                double update = f.applyAsDouble(tuple.value);
                elements[tuple.index] += update;
                gradSquares[tuple.index] += update*update;
            }
        }

        @Override
        public int indexOfMax() {
            int index = 0;
            double value = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < elements.length; i++) {
                double tmp = get(i);
                if (tmp > value) {
                    index = i;
                    value = tmp;
                }
            }
            return index;
        }

        @Override
        public double maxValue() {
            double value = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < elements.length; i++) {
                double tmp = get(i);
                if (tmp > value) {
                    value = tmp;
                }
            }
            return value;
        }

        @Override
        public double minValue() {
            double value = Double.POSITIVE_INFINITY;
            for (int i = 0; i < elements.length; i++) {
                double tmp = get(i);
                if (tmp < value) {
                    value = tmp;
                }
            }
            return value;
        }

        @Override
        public double dot(SGDVector other) {
            double score = 0.0;

            for (VectorTuple tuple : other) {
                score += get(tuple.index) * tuple.value;
            }

            return score;
        }

        @Override
        public VectorIterator iterator() {
            return new RDAVectorIterator(this);
        }

        private static class RDAVectorIterator implements VectorIterator {
            private final AdaGradRDAVector vector;
            private final VectorTuple tuple;
            private int index;

            public RDAVectorIterator(AdaGradRDAVector vector) {
                this.vector = vector;
                this.tuple = new VectorTuple();
                this.index = 0;
            }

            @Override
            public boolean hasNext() {
                return index < vector.size();
            }

            @Override
            public VectorTuple next() {
                tuple.index = index;
                tuple.value = vector.get(index);
                index++;
                return tuple;
            }

            @Override
            public VectorTuple getReference() {
                return tuple;
            }
        }
    }

    /**
     * A subclass of {@link DenseMatrix} which accumulates the gradient updates in place and uses
     * {@link AdaGradRDATensor#truncate(double, double)} to produce the regularised values when
     * they are read.
     * <p>
     * Be careful when modifying this or {@link DenseMatrix}.
     */
    private static class AdaGradRDAMatrix extends DenseMatrix implements AdaGradRDATensor {
        private final double learningRate;
        private final double epsilon;
        private final double l1;
        private final double l2;
        private final double[][] gradSquares;
        private int iteration;

        public AdaGradRDAMatrix(DenseMatrix v, double learningRate, double epsilon, double l1, double l2) {
            super(v);
            this.learningRate = learningRate;
            this.epsilon = epsilon;
            this.l1 = l1;
            this.l2 = l2;
            this.gradSquares = new double[v.getDimension1Size()][v.getDimension2Size()];
            this.iteration = 0;
        }

        @Override
        public DenseMatrix convertToDense() {
            return new DenseMatrix(this);
        }

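        // Overridden so the multiply reads the regularised values via get(i,j) rather than the
        // raw accumulated values.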
        @Override
        public DenseVector leftMultiply(SGDVector input) {
            if (input.size() == dim2) {
                double[] output = new double[dim1];
                for (VectorTuple tuple : input) {
                    for (int i = 0; i < output.length; i++) {
                        output[i] += get(i,tuple.index) * tuple.value;
                    }
                }

                return DenseVector.createDenseVector(output);
            } else {
                throw new IllegalArgumentException("input.size() != dim2");
            }
        }

        @Override
        public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
            // Mirror the vector update: count each gradient application as an RDA iteration so
            // the l1 and l2 penalties used in get(i,j) are scaled correctly.
            iteration++;
            if (other instanceof Matrix) {
                Matrix otherMat = (Matrix) other;
                if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) {
                    for (MatrixTuple tuple : otherMat) {
                        double update = f.applyAsDouble(tuple.value);
                        values[tuple.i][tuple.j] += update;
                        gradSquares[tuple.i][tuple.j] += update*update;
                    }
                } else {
                    throw new IllegalStateException("Matrices are not the same size, this("+dim1+","+dim2+"), other("+otherMat.getDimension1Size()+","+otherMat.getDimension2Size()+")");
                }
            } else {
                throw new IllegalStateException("Adding a non-Matrix to a Matrix");
            }
        }

        @Override
        public double get(int i, int j) {
            if (gradSquares[i][j] == 0.0) {
                // No gradients seen yet, return the stored value directly.
                return values[i][j];
            } else {
                // RDA update: l1-truncate the accumulated value, then scale by the adaptive rate 1/h,
                // where h = (sqrt(sum of squared gradients) + epsilon)/learningRate + iteration*l2.
                double h = ((Math.sqrt(gradSquares[i][j]) + epsilon) / learningRate) + iteration * l2;
                double rate = 1.0/h;
                return rate * AdaGradRDATensor.truncate(values[i][j], iteration*l1);
            }
        }

        @Override
        public MatrixIterator iterator() {
            return new RDAMatrixIterator(this);
        }

        private static class RDAMatrixIterator implements MatrixIterator {
            private final AdaGradRDAMatrix matrix;
            private final MatrixTuple tuple;
            private final int dim2;
            private int i;
            private int j;

            public RDAMatrixIterator(AdaGradRDAMatrix matrix) {
                this.matrix = matrix;
                this.tuple = new MatrixTuple();
                this.dim2 = matrix.dim2;
                this.i = 0;
                this.j = 0;
            }

            @Override
            public MatrixTuple getReference() {
                return tuple;
            }

            @Override
            public boolean hasNext() {
                return (i < matrix.dim1) && (j < matrix.dim2);
            }

            @Override
            public MatrixTuple next() {
                tuple.i = i;
                tuple.j = j;
                tuple.value = matrix.get(i,j);
                if (j < dim2-1) {
                    j++;
                } else {
                    //Reached end of current vector, get next one
                    i++;
                    j = 0;
                }
                return tuple;
            }
        }

    }
}