/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/**
 * Averages the parameters across a gradient run.
 * <p>
 * Wraps an inner gradient optimiser. Only changes the values when
 * {@link ParameterAveraging#finalise()} is called.
 * <p>
 * See:
 * <pre>
 * Polyak BT, Juditsky AB.
 * "Acceleration of Stochastic Approximation by Averaging"
 * SIAM Journal on Control and Optimization, 1992.
 * </pre>
 */
public class ParameterAveraging implements StochasticGradientOptimiser {

    @Config(mandatory = true, description = "Inner optimiser to average parameters across.")
    private StochasticGradientOptimiser optimiser;

    private int iterations = 0;
    private Tensor[] weights;
    private Parameters parameters;

    /**
     * Adds parameter averaging around a gradient optimiser.
     * @param optimiser The inner optimiser to use to scale the gradients.
     */
    public ParameterAveraging(StochasticGradientOptimiser optimiser) {
        this.optimiser = optimiser;
    }

    /**
     * For olcut.
     */
    private ParameterAveraging() { }

    @Override
    public void initialise(Parameters parameters) {
        optimiser.initialise(parameters);
        weights = parameters.getEmptyCopy();
        this.parameters = parameters;
    }

    /**
     * This passes the gradient update to the inner optimiser, then updates
     * the average weight values.
     * @param updates An array of gradients.
     * @param weight The weight for the current gradients.
     * @return The gradients from the inner optimiser.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        iterations++;
        Tensor[] output = optimiser.step(updates, weight);
        for (int i = 0; i < output.length; i++) {
            weights[i].intersectAndAddInPlace(output[i], (double a) -> a * iterations);
        }
        return output;
    }
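
    /*
     * Why this accumulation yields the average: at step t the inner
     * optimiser produces an update delta_t, and the loop above adds
     * t * delta_t to `weights`, so after T steps
     * weights = sum_{t=1..T} t * delta_t, while the live parameters hold
     * w_T = w_0 + sum_{t=1..T} delta_t. finalise() below subtracts
     * weights / T from w_T, leaving w_0 + sum_{t=1..T} (1 - t/T) * delta_t,
     * which is exactly the average of the iterates w_0, ..., w_{T-1}.
     */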

    /**
     * This sets the parameters to their average value.
     */
    @Override
    public void finalise() {
        Tensor[] tmp = parameters.get();
        for (int i = 0; i < tmp.length; i++) {
            tmp[i].intersectAndAddInPlace(weights[i], (double a) -> -a / iterations);
        }
    }

    @Override
    public String toString() {
        return "ParameterAveraging(optimiser=" + optimiser.toString() + ")";
    }

    @Override
    public void reset() {
        optimiser.reset();
        iterations = 0;
        weights = null;
    }

    @Override
    public ParameterAveraging copy() {
        return new ParameterAveraging(optimiser.copy());
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
    }
}
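
/*
 * Usage sketch (illustrative only; it assumes Tribuo's SGD factory methods
 * and Parameters.update(Tensor[]) are available in this form, and
 * `gradientStream` is a hypothetical source of per-example gradients):
 *
 *   StochasticGradientOptimiser inner = SGD.getLinearDecaySGD(0.1);
 *   ParameterAveraging averaged = new ParameterAveraging(inner);
 *   averaged.initialise(parameters);
 *   for (Tensor[] gradients : gradientStream) {
 *       parameters.update(averaged.step(gradients, 1.0));
 *   }
 *   averaged.finalise(); // overwrite the parameters with their average
 *
 * Because ParameterAveraging decorates another optimiser, the training loop
 * drives the wrapper rather than the inner optimiser directly.
 */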