/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/**
 * Averages the parameters across a gradient run.
 * <p>
 * Wraps an inner gradient optimiser. The parameter values are only changed when
 * {@link ParameterAveraging#finalise()} is called.
 * <p>
 * See:
 * <pre>
 * Polyak BT, Juditsky AB.
 * "Acceleration of Stochastic Approximation by Averaging"
 * SIAM Journal on Control and Optimization, 1992.
 * </pre>
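 * <p>
 * Example usage (a minimal sketch; it assumes {@code AdaGrad} from this package as the inner
 * optimiser and a training loop elsewhere that supplies the gradients):
 * <pre>
 * StochasticGradientOptimiser inner = new AdaGrad(0.1);
 * ParameterAveraging averaged = new ParameterAveraging(inner);
 * averaged.initialise(parameters);
 * // inside the training loop: averaged.step(gradients, weight);
 * averaged.finalise(); // writes the averaged weights into the parameters
 * </pre>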
 */
public class ParameterAveraging implements StochasticGradientOptimiser {

    @Config(mandatory = true, description = "Inner optimiser to average parameters across.")
    private StochasticGradientOptimiser optimiser;

    // Number of gradient steps taken so far.
    private int iterations = 0;
    // Accumulated iteration-weighted updates, consumed by finalise() to compute the average.
    private Tensor[] weights;
    // The model parameters this optimiser was initialised with.
    private Parameters parameters;

    /**
     * Adds parameter averaging around a gradient optimiser.
     * @param optimiser The inner optimiser to use to scale the gradients.
     */
    public ParameterAveraging(StochasticGradientOptimiser optimiser) {
        this.optimiser = optimiser;
    }

    /**
     * For olcut.
     */
    private ParameterAveraging() { }

    @Override
    public void initialise(Parameters parameters) {
        optimiser.initialise(parameters);
        weights = parameters.getEmptyCopy();
        this.parameters = parameters;
    }

    /**
     * This passes the gradient update to the inner optimiser, then updates
     * the average weight values.
     * @param updates An array of gradients.
     * @param weight The weight for the current gradients.
     * @return The gradients from the inner optimiser.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        iterations++;
        Tensor[] output = optimiser.step(updates, weight);
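        // Accumulate iteration-weighted updates: after T steps this holds sum_t (t * g_t).
        // finalise() subtracts this sum divided by T from the final parameters, which
        // recovers the average of the parameter iterates without storing each one.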
        for (int i = 0; i < output.length; i++) {
            weights[i].intersectAndAddInPlace(output[i], (double a) -> a * iterations);
        }
        return output;
    }

    /**
     * This sets the parameters to their average value.
     */
    @Override
    public void finalise() {
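        // The parameters currently hold the final iterate; subtracting the accumulated
        // iteration-weighted updates divided by the iteration count leaves the averaged weights.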
        Tensor[] tmp = parameters.get();
        for (int i = 0; i < tmp.length; i++) {
            tmp[i].intersectAndAddInPlace(weights[i], (double a) -> -a / iterations);
        }
    }

    @Override
    public String toString() {
        return "ParameterAveraging(optimiser=" + optimiser.toString() + ")";
    }

    @Override
    public void reset() {
        optimiser.reset();
        iterations = 0;
        weights = null;
    }

    @Override
    public ParameterAveraging copy() {
        return new ParameterAveraging(optimiser.copy());
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
    }
}