/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

import java.util.function.DoubleUnaryOperator;

/**
 * An implementation of the Adam gradient optimiser.
 * <p>
 * Creates two copies of the parameters to store the first and second moment estimates of the gradient.
 * <p>
 * See:
 * <pre>
 * Kingma, D., and Ba, J.
 * "Adam: A Method for Stochastic Optimization"
 * arXiv preprint arXiv:1412.6980, 2014.
 * </pre>
 */
public class Adam implements StochasticGradientOptimiser {

    @Config(description="Learning rate to scale the gradients by.")
    private double initialLearningRate = 0.001;

    @Config(description="The beta one parameter.")
    private double betaOne = 0.9;

    @Config(description="The beta two parameter.")
    private double betaTwo = 0.99;

    @Config(description="Epsilon for numerical stability.")
    private double epsilon = 1e-6;

    private int iterations = 0;
    private Tensor[] firstMoment;
    private Tensor[] secondMoment;

    /**
     * Constructs an Adam optimiser with the supplied hyperparameters.
     * It's highly recommended not to modify these parameters; use one of the other constructors.
     * @param initialLearningRate The initial learning rate.
     * @param betaOne The value of beta-one.
     * @param betaTwo The value of beta-two.
     * @param epsilon The epsilon value.
     */
    public Adam(double initialLearningRate, double betaOne, double betaTwo, double epsilon) {
        this.initialLearningRate = initialLearningRate;
        this.betaOne = betaOne;
        this.betaTwo = betaTwo;
        this.epsilon = epsilon;
        this.iterations = 0;
    }

    /**
     * Sets betaOne to 0.9 and betaTwo to 0.999.
     * @param initialLearningRate The initial learning rate.
     * @param epsilon The epsilon value.
     */
    public Adam(double initialLearningRate, double epsilon) {
        this(initialLearningRate,0.9,0.999,epsilon);
    }

    /**
     * Sets initialLearningRate to 0.001, betaOne to 0.9, betaTwo to 0.999, epsilon to 1e-6.
     * These are the parameters from the Adam paper.
     */
    public Adam() {
        this(0.001,0.9,0.999,1e-6);
    }

    @Override
    public void initialise(Parameters parameters) {
        firstMoment = parameters.getEmptyCopy();
        secondMoment = parameters.getEmptyCopy();
        iterations = 0;
    }

    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        iterations++;

        // Fold both bias corrections into the step size: alpha_t = alpha * sqrt(1 - betaTwo^t) / (1 - betaOne^t).
        double learningRate = initialLearningRate * Math.sqrt(1.0 - Math.pow(betaTwo,iterations)) / (1.0 - Math.pow(betaOne,iterations));
        //lifting lambdas out of the for loop until JDK-8183316 is fixed.
        DoubleUnaryOperator scale = (double a) -> a * learningRate;

        for (int i = 0; i < updates.length; i++) {
            firstMoment[i].scaleInPlace(betaOne);
            firstMoment[i].intersectAndAddInPlace(updates[i],(double a) -> a * (1.0 - betaOne));
            secondMoment[i].scaleInPlace(betaTwo);
            secondMoment[i].intersectAndAddInPlace(updates[i],(double a) -> a * a * (1.0 - betaTwo));
            updates[i].scaleInPlace(0.0); // scales everything to zero, but keeps the sparse presence pattern
            updates[i].intersectAndAddInPlace(firstMoment[i],scale); // add in the first moment
            updates[i].hadamardProductInPlace(secondMoment[i],(double a) -> 1.0 / (Math.sqrt(a) + epsilon)); // divide by the root of the second moment
        }

        return updates;
    }

    @Override
    public String toString() {
        return "Adam(learningRate="+initialLearningRate+",betaOne="+betaOne+",betaTwo="+betaTwo+",epsilon="+epsilon+")";
    }

    @Override
    public void reset() {
        firstMoment = null;
        secondMoment = null;
        iterations = 0;
    }

    @Override
    public Adam copy() {
        return new Adam(initialLearningRate,betaOne,betaTwo,epsilon);
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
    }
}
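
/*
 * Usage sketch (illustrative only, not part of the original file). In Tribuo this optimiser is
 * normally handed to a gradient-based Trainer, which calls initialise/step/reset itself. The
 * hand-driven loop below uses only the API visible in this file, plus an assumed
 * Parameters.update(Tensor[]) call and an assumed source of gradient Tensors:
 *
 *   StochasticGradientOptimiser optimiser = new Adam(0.001, 1e-8);
 *   optimiser.initialise(modelParameters);                      // modelParameters is an org.tribuo.math.Parameters
 *   for (int step = 0; step < numSteps; step++) {
 *       Tensor[] gradients = ...;                               // minibatch gradients w.r.t. the parameters
 *       modelParameters.update(optimiser.step(gradients, 1.0)); // step() rescales the gradients in place
 *   }
 *   optimiser.reset();
 */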