/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.ParameterAveraging;

/**
 * Interface for gradient based optimisation methods.
 * <p>
 * Order of use:
 * <ul>
 * <li>{@link StochasticGradientOptimiser#initialise(Parameters)}</li>
 * <li>take many {@link StochasticGradientOptimiser#step(Tensor[], double)}s</li>
 * <li>{@link StochasticGradientOptimiser#finalise()}</li>
 * <li>{@link StochasticGradientOptimiser#reset()}</li>
 * </ul>
 *
 * Deviating from this order will cause unexpected behaviour.
 */
public interface StochasticGradientOptimiser extends Configurable, Provenancable<ConfiguredObjectProvenance> {

    /**
     * Initialises the gradient optimiser.
     * <p>
     * Configures any learning rate parameters.
     * @param parameters The parameters to optimise.
     */
    default public void initialise(Parameters parameters) {}

    /**
     * Takes a {@link Tensor} array of gradients and transforms them
     * according to the current weight and learning rates.
     * <p>
     * Can return the same {@link Tensor} array or a new one.
     * @param updates An array of gradients.
     * @param weight The weight for the current gradients.
     * @return A {@link Tensor} array of gradients.
     */
    public Tensor[] step(Tensor[] updates, double weight);

    /**
     * Finalises the gradient optimisation, setting the parameters to their correct values.
     * Used for {@link ParameterAveraging} amongst others.
     */
    default public void finalise() {}

    /**
     * Resets the optimiser so it's ready to optimise a new {@link Parameters}.
     */
    public void reset();

    /**
     * Copies a gradient optimiser with its configuration. Usually calls the copy constructor.
     * @return A gradient optimiser with the same configuration, but independent state.
     */
    public StochasticGradientOptimiser copy();
}
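
/*
 * Illustrative sketch only, not part of the Tribuo API: a minimal training loop showing the
 * call order documented on the interface above. The computeGradients helper is hypothetical,
 * standing in for a model-specific backward pass, and Parameters#update is assumed here to
 * apply the transformed gradients to the underlying weights.
 */
final class StochasticGradientOptimiserUsageSketch {

    private StochasticGradientOptimiserUsageSketch() {}

    static void trainingLoop(StochasticGradientOptimiser optimiser, Parameters parameters, int numSteps) {
        // 1. Initialise the optimiser against the parameters it will update.
        optimiser.initialise(parameters);

        for (int i = 0; i < numSteps; i++) {
            // 2. Compute the raw gradients (hypothetical helper), then let the optimiser
            //    rescale them according to its learning rate schedule and internal state.
            Tensor[] gradients = computeGradients(parameters);
            Tensor[] scaled = optimiser.step(gradients, 1.0);
            // Apply the transformed gradients to the parameters.
            parameters.update(scaled);
        }

        // 3. Finalise the optimisation (e.g. ParameterAveraging writes the averaged weights back).
        optimiser.finalise();

        // 4. Reset the optimiser before reusing it on a different Parameters object.
        optimiser.reset();
    }

    // Hypothetical stand-in for a model's gradient computation.
    private static Tensor[] computeGradients(Parameters parameters) {
        throw new UnsupportedOperationException("Placeholder for a model-specific backward pass");
    }
}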