/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/**
 * An implementation of the AdaDelta gradient optimiser.
 * <p>
 * Creates two copies of the parameters to store the accumulated
 * squared gradients and squared updates.
 * <p>
 * See:
 * <pre>
 * Zeiler, MD.
 * "ADADELTA: an Adaptive Learning Rate Method"
 * arXiv preprint arXiv:1212.5701.
 * </pre>
 */
public class AdaDelta implements StochasticGradientOptimiser {

    @Config(description="Momentum value.")
    private double rho = 0.95;

    @Config(description="Epsilon for numerical stability.")
    private double epsilon = 1e-6;

    private Tensor[] gradsSquared;
    private Tensor[] velocitySquared;

    /**
     * Constructs an AdaDelta optimiser with the supplied rho and epsilon.
     * It's recommended to keep rho at 0.95.
     * @param rho The rho value.
     * @param epsilon The epsilon value.
     */
    public AdaDelta(double rho, double epsilon) {
        this.rho = rho;
        this.epsilon = epsilon;
    }

    /**
     * Constructs an AdaDelta optimiser with rho fixed at 0.95 and the supplied epsilon.
     * @param epsilon The epsilon value.
     */
    public AdaDelta(double epsilon) {
        this(0.95,epsilon);
    }

    /**
     * Constructs an AdaDelta optimiser with the default values, rho = 0.95 and epsilon = 1e-6.
     */
    public AdaDelta() {
        this(0.95,1e-6);
    }

    @Override
    public void initialise(Parameters parameters) {
        gradsSquared = parameters.getEmptyCopy();
        velocitySquared = parameters.getEmptyCopy();
    }

    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        for (int i = 0; i < updates.length; i++) {
            // Decay the accumulated squared gradient, then add in the new squared gradient:
            // E[g^2] = rho * E[g^2] + (1 - rho) * g^2.
            gradsSquared[i].scaleInPlace(rho);
            gradsSquared[i].intersectAndAddInPlace(updates[i],(double a) -> a * a * (1.0 - rho));
            // Rescale the gradient by RMS(update) / RMS(gradient) to produce the AdaDelta step.
            updates[i].hadamardProductInPlace(velocitySquared[i],(double a) -> Math.sqrt(a + epsilon));
            updates[i].hadamardProductInPlace(gradsSquared[i],(double a) -> 1.0 / (Math.sqrt(a + epsilon)));
            // Decay the accumulated squared update, then add in the new squared update.
            velocitySquared[i].scaleInPlace(rho);
            velocitySquared[i].intersectAndAddInPlace(updates[i],(double a) -> a * a * (1.0 - rho));
        }

        return updates;
    }

    @Override
    public String toString() {
        return "AdaDelta(rho="+rho+",epsilon="+epsilon+")";
    }

    @Override
    public void reset() {
        gradsSquared = null;
        velocitySquared = null;
    }

    @Override
    public AdaDelta copy() {
        return new AdaDelta(rho,epsilon);
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
    }
}
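
/*
 * A minimal usage sketch, illustrative only and not part of the Tribuo sources:
 * it drives the optimiser lifecycle (initialise -> step -> update -> reset) by hand,
 * the same pattern Tribuo's SGD trainers follow. It assumes Tribuo's
 * org.tribuo.math.LinearParameters(numFeatures, numLabels) constructor; the zero
 * "gradients" below are a placeholder for the gradients a real training loop would
 * compute from a loss function.
 */
class AdaDeltaUsageSketch {
    public static void main(String[] args) {
        // A small linear model: 10 features, 2 outputs.
        Parameters params = new org.tribuo.math.LinearParameters(10, 2);
        AdaDelta optimiser = new AdaDelta();
        optimiser.initialise(params); // allocates the two accumulator copies

        for (int i = 0; i < 100; i++) {
            // Placeholder gradients with the correct shapes; a real loop fills these from the loss.
            Tensor[] gradients = params.getEmptyCopy();
            Tensor[] updates = optimiser.step(gradients, 1.0);
            params.update(updates); // apply the rescaled step to the parameters
        }

        optimiser.reset(); // clear the accumulators before reusing the optimiser
    }
}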