001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math;
018
019import com.oracle.labs.mlrg.olcut.config.Configurable;
020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
022import org.tribuo.math.la.Tensor;
023import org.tribuo.math.optimisers.ParameterAveraging;
024
025/**
026 * Interface for gradient based optimisation methods.
027 * <p>
028 * Order of use:
029 * <ul>
030 * <li>{@link StochasticGradientOptimiser#initialise(Parameters)}</li>
031 * <li>take many {@link StochasticGradientOptimiser#step(Tensor[], double)}s</li>
032 * <li>{@link StochasticGradientOptimiser#finalise()}</li>
033 * <li>{@link StochasticGradientOptimiser#reset()}</li>
034 * </ul>
035 *
036 * Deviating from this order will cause unexpected behaviour.
037 */
038public interface StochasticGradientOptimiser extends Configurable, Provenancable<ConfiguredObjectProvenance> {
039
040    /**
041     * Initialises the gradient optimiser.
042     * <p>
043     * Configures any learning rate parameters.
044     * @param parameters The parameters to optimise.
045     */
046    default public void initialise(Parameters parameters) {}
047
048    /**
049     * Take a {@link Tensor} array of gradients and transform them
050     * according to the current weight and learning rates.
051     * <p>
052     * Can return the same {@link Tensor} array or a new one.
053     * @param updates An array of gradients.
054     * @param weight The weight for the current gradients.
055     * @return A {@link Tensor} array of gradients.
056     */
057    public Tensor[] step(Tensor[] updates, double weight);
058
059    /**
060     * Finalises the gradient optimisation, setting the parameters to their correct values.
061     * Used for {@link ParameterAveraging} amongst others.
062     */
063    default public void finalise() {}
064
065    /**
066     * Resets the optimiser so it's ready to optimise a new {@link Parameters}.
067     */
068    public void reset();
069
070    /**
071     * Copies a gradient optimiser with it's configuration. Usually calls the copy constructor.
072     * @return A gradient optimiser with the same configuration, but independent state.
073     */
074    public StochasticGradientOptimiser copy();
075}