/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;

/**
 * An implementation of single learning rate SGD, optionally with momentum.
 * <p>
 * Has factory methods to generate constant learning rate, linear decay and sqrt decay variants.
 * <p>
 * See:
 * <pre>
 * Bottou L.
 * "Large-Scale Machine Learning with Stochastic Gradient Descent"
 * Proceedings of COMPSTAT, 2010.
 * </pre>
 * and for the momentum implementation:
 * <pre>
 * Shallue et al,
 * "Measuring the Effects of Data Parallelism on Neural Network Training"
 * 2018, Arxiv 1811.03600
 * </pre>
 */
public abstract class SGD implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(SGD.class.getName());

    /**
     * Momentum types.
     */
    public enum Momentum {
        /**
         * No momentum.
         */
        NONE,
        /**
         * Standard momentum.
         */
        STANDARD,
        /**
         * Nesterov momentum.
         */
        NESTEROV
    }

    @Config(mandatory = true,description="Initial learning rate.")
    protected double initialLearningRate;

    @Config(mandatory = true,description="Momentum type to use.")
    protected Momentum useMomentum;

    @Config(description="Momentum scaling factor.")
    protected double rho = 0.0;

    protected int iteration = 0;

    private Tensor[] momentum;

    SGD(double learningRate) {
        this(learningRate,0.0,Momentum.NONE);
    }

    SGD(double learningRate, double rho, Momentum useMomentum) {
        this.initialLearningRate = learningRate;
        this.useMomentum = useMomentum;
        this.rho = rho;
    }

    /**
     * For olcut.
     */
    protected SGD() { }

    @Override
    public void initialise(Parameters parameters) {
        if (useMomentum != Momentum.NONE) {
            momentum = parameters.getEmptyCopy();
        }
    }

    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        iteration++;
        double learningRate = learningRate();
        DoubleUnaryOperator learningRateFunc = (double a) -> a * learningRate * weight;
        DoubleUnaryOperator nesterovFunc = (double a) -> a * learningRate * weight * rho;

        /* Modelled after momentum as described in
         * "Measuring the Effects of Data Parallelism on Neural Network Training"
         * Shallue et al 2018, Arxiv 1811.03600
         */
        for (int i = 0; i < updates.length; i++) {
            switch (useMomentum) {
                case STANDARD:
                    // velocity = rho * velocity + gradient; update = learningRate * weight * velocity
                    momentum[i].scaleInPlace(rho);
                    momentum[i].intersectAndAddInPlace(updates[i]);
                    updates[i].scaleInPlace(0.0);
                    updates[i].intersectAndAddInPlace(momentum[i],learningRateFunc);
                    break;
                case NESTEROV:
                    // velocity = rho * velocity + gradient; update = learningRate * weight * (gradient + rho * velocity)
                    momentum[i].scaleInPlace(rho);
                    momentum[i].intersectAndAddInPlace(updates[i]);
                    updates[i].scaleInPlace(weight * learningRate);
                    updates[i].intersectAndAddInPlace(momentum[i],nesterovFunc);
                    break;
                case NONE:
                default:
                    // plain SGD: update = learningRate * weight * gradient
                    updates[i].scaleInPlace(weight * learningRate);
                    break;
            }
        }

        return updates;
    }

    /**
     * Override to provide a function which calculates the learning rate.
     * The only available information is the iteration count.
     * @return The current learning rate.
     */
    public abstract double learningRate();

    /**
     * Override to specify the kind of SGD.
     * @return A string representing the SGD type.
     */
    protected abstract String sgdType();

    @Override
    public String toString() {
        switch (useMomentum) {
            case STANDARD:
                return "SGD+Momentum(type=" + sgdType() + ",initialLearningRate=" + initialLearningRate + ",rho=" + rho + ")";
            case NESTEROV:
                return "SGD+NesterovMomentum(type=" + sgdType() + ",initialLearningRate=" + initialLearningRate + ",rho=" + rho + ")";
            default:
                return "SGD(type=" + sgdType() + ",initialLearningRate=" + initialLearningRate + ")";
        }
    }

    @Override
    public void reset() {
        momentum = null;
        iteration = 0;
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
    }

    /**
     * Generates an SGD optimiser with a constant learning rate set to learningRate.
     * @param learningRate The learning rate.
     * @return A constant learning rate SGD.
     */
    public static SGD getSimpleSGD(double learningRate) {
        return new SimpleSGD(learningRate);
    }

    /**
     * Generates an SGD optimiser with a constant learning rate set to learningRate, with momentum.
     * @param learningRate The learning rate.
     * @param rho The momentum drag constant.
     * @param momentumType Momentum type.
     * @return A constant learning rate SGD with momentum.
     */
    public static SGD getSimpleSGD(double learningRate, double rho, Momentum momentumType) {
        return new SimpleSGD(learningRate, rho, momentumType);
    }

    /**
     * Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate.
     *
     * The learning rate = initialLearningRate / iteration.
     * @param learningRate The learning rate.
     * @return A linear decay SGD.
     */
    public static SGD getLinearDecaySGD(double learningRate) {
        return new LinearDecaySGD(learningRate);
    }

    /**
     * Generates an SGD optimiser with a linearly decaying learning rate initialised to learningRate, with momentum.
     *
     * The learning rate = initialLearningRate / iteration.
     * @param learningRate The learning rate.
     * @param rho The momentum drag constant.
     * @param momentumType Momentum type.
     * @return A linear decay SGD with momentum.
     */
    public static SGD getLinearDecaySGD(double learningRate, double rho, Momentum momentumType) {
        return new LinearDecaySGD(learningRate, rho, momentumType);
    }

    /**
     * Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate.
     *
     * The learning rate = initialLearningRate / sqrt(iteration).
     * @param learningRate The learning rate.
     * @return A sqrt decay SGD.
     */
    public static SGD getSqrtDecaySGD(double learningRate) {
        return new SqrtDecaySGD(learningRate);
    }

    /**
     * Generates an SGD optimiser with a sqrt decaying learning rate initialised to learningRate, with momentum.
     *
     * The learning rate = initialLearningRate / sqrt(iteration).
     * @param learningRate The learning rate.
     * @param rho The momentum drag constant.
     * @param momentumType Momentum type.
     * @return A sqrt decay SGD with momentum.
     */
    public static SGD getSqrtDecaySGD(double learningRate, double rho, Momentum momentumType) {
        return new SqrtDecaySGD(learningRate, rho, momentumType);
    }
}

final class SimpleSGD extends SGD {
    public SimpleSGD(double learningRate) {
        super(learningRate);
    }

    public SimpleSGD(double learningRate, double rho, Momentum momentumType) {
        super(learningRate, rho, momentumType);
    }

    protected SimpleSGD() { }

    @Override
    public double learningRate() {
        return initialLearningRate;
    }

    @Override
    protected String sgdType() {
        return "Constant";
    }

    @Override
    public SimpleSGD copy() {
        return new SimpleSGD(initialLearningRate,rho,useMomentum);
    }
}

final class LinearDecaySGD extends SGD {
    public LinearDecaySGD(double learningRate) {
        super(learningRate);
    }

    public LinearDecaySGD(double learningRate, double rho, Momentum momentumType) {
        super(learningRate, rho, momentumType);
    }

    protected LinearDecaySGD() { }

    @Override
    public double learningRate() {
        return initialLearningRate / iteration;
    }

    @Override
    protected String sgdType() {
        return "LinearDecay";
    }

    @Override
    public LinearDecaySGD copy() {
        return new LinearDecaySGD(initialLearningRate,rho,useMomentum);
    }
}

final class SqrtDecaySGD extends SGD {
    public SqrtDecaySGD(double learningRate) {
        super(learningRate);
    }

    public SqrtDecaySGD(double learningRate, double rho, Momentum momentumType) {
        super(learningRate, rho, momentumType);
    }

    protected SqrtDecaySGD() { }

    @Override
    public double learningRate() {
        return initialLearningRate / Math.sqrt(iteration);
    }

    @Override
    protected String sgdType() {
        return "SqrtDecay";
    }

    @Override
    public SqrtDecaySGD copy() {
        return new SqrtDecaySGD(initialLearningRate,rho,useMomentum);
    }
}
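
/*
 * Usage sketch (not part of the original file): a rough outline of how the factory methods
 * and the optimiser lifecycle fit together. In practice the loop below lives inside a Tribuo
 * trainer; the Parameters instance ("modelParameters"), the per-minibatch gradient Tensor[]
 * and the "minibatchWeight" scaling factor are assumed to come from the model being trained,
 * so treat this as an illustrative sketch rather than the exact trainer code.
 *
 *     SGD optimiser = SGD.getSqrtDecaySGD(0.1, 0.9, SGD.Momentum.STANDARD);
 *     optimiser.initialise(modelParameters);          // allocates momentum buffers when momentum is enabled
 *     for (int i = 0; i < numMinibatches; i++) {
 *         Tensor[] gradients = ...;                   // gradients for this minibatch, produced by the model
 *         modelParameters.update(optimiser.step(gradients, minibatchWeight));
 *     }
 *     optimiser.reset();                              // clears the iteration count and momentum before reuse
 */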