/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.util.ShrinkingMatrix;
import org.tribuo.math.optimisers.util.ShrinkingVector;

import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;

/**
 * An implementation of the RMSProp gradient optimiser.
 * <p>
 * Creates one copy of the parameters to store learning rates.
 * Follows the Keras implementation.
 * <p>
 * See:
 * <pre>
 * Tieleman, T. and Hinton, G.
 * Lecture 6.5 - RMSProp, COURSERA: Neural Networks for Machine Learning.
 * Technical report, 2012.
 * </pre>
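 * <p>
 * A minimal usage sketch (illustrative only; any SGD-based Tribuo trainer that
 * accepts a {@link StochasticGradientOptimiser} can consume this optimiser):
 * <pre>
 * StochasticGradientOptimiser optimiser = new RMSProp(0.01, 0.9);
 * // pass {@code optimiser} into the trainer's constructor
 * </pre>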
 */
public class RMSProp implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(RMSProp.class.getName());

    @Config(mandatory = true,description="Learning rate to scale the gradients by.")
    private double initialLearningRate;

    @Config(description="Momentum parameter.")
    private double rho = 0.9;

    @Config(description="Epsilon for numerical stability.")
    private double epsilon = 1e-8;

    @Config(description="Decay factor for the learning rate.")
    private double decay = 0.0;

    private double invRho;

    private int iteration = 0;

    private Tensor[] gradsSquared;

    private DoubleUnaryOperator square;

    /**
     * Constructs an RMSProp optimiser with the supplied hyperparameters.
     * @param initialLearningRate The initial learning rate used to scale the gradients.
     * @param rho The decay factor applied to the moving average of squared gradients.
     * @param epsilon The epsilon value for numerical stability.
     * @param decay The decay factor applied to the learning rate.
     */
    public RMSProp(double initialLearningRate, double rho, double epsilon, double decay) {
        this.initialLearningRate = initialLearningRate;
        this.rho = rho;
        this.epsilon = epsilon;
        this.decay = decay;
        this.iteration = 0;
        postConfig();
    }

    /**
     * Constructs an RMSProp optimiser with the default epsilon (1e-8) and no learning rate decay.
     * @param initialLearningRate The initial learning rate used to scale the gradients.
     * @param rho The decay factor applied to the moving average of squared gradients.
     */
    public RMSProp(double initialLearningRate, double rho) {
        this(initialLearningRate,rho,1e-8,0.0);
    }

    /**
     * For olcut.
     */
    private RMSProp() { }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        this.invRho = 1.0 - rho;
        // Scales each squared gradient by (1 - rho) before it is accumulated.
        this.square = (double a) -> invRho*a*a;
    }

    @Override
    public void initialise(Parameters parameters) {
        gradsSquared = parameters.getEmptyCopy();
        // Wrap the accumulators in shrinking tensors so the squared gradient moving average decays each step.
        for (int i = 0; i < gradsSquared.length; i++) {
            if (gradsSquared[i] instanceof DenseVector) {
                gradsSquared[i] = new ShrinkingVector(((DenseVector) gradsSquared[i]), invRho, false);
            } else if (gradsSquared[i] instanceof DenseMatrix) {
                gradsSquared[i] = new ShrinkingMatrix(((DenseMatrix) gradsSquared[i]), invRho, false);
            } else {
                throw new IllegalStateException("Unknown Tensor subclass");
            }
        }
    }

    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        // Anneal the learning rate: lr_t = lr_0 / (1 + decay * t).
        double learningRate = initialLearningRate / (1 + decay * iteration);
        //lifting lambdas out of the for loop until JDK-8183316 is fixed.
        DoubleUnaryOperator scale = (double a) -> weight * learningRate / (epsilon + Math.sqrt(a));
        for (int i = 0; i < updates.length; i++) {
            Tensor curGradsSquared = gradsSquared[i];
            Tensor curGrad = updates[i];
            // Update the moving average of squared gradients: r = rho * r + (1 - rho) * g^2.
            curGradsSquared.intersectAndAddInPlace(curGrad,square);
            // Rescale the gradient: g = weight * lr_t * g / (epsilon + sqrt(r)).
            curGrad.hadamardProductInPlace(curGradsSquared,scale);
        }

        iteration++;
        return updates;
    }

    @Override
    public String toString() {
        return "RMSProp(initialLearningRate="+initialLearningRate+",rho="+rho+",epsilon="+epsilon+",decay="+decay+")";
    }

    @Override
    public void reset() {
        gradsSquared = null;
        iteration = 0;
    }

    @Override
    public RMSProp copy() {
        return new RMSProp(initialLearningRate,rho,epsilon,decay);
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
    }
}