/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import org.tribuo.math.StochasticGradientOptimiser;

import java.util.logging.Logger;

/**
 * CLI options for configuring a gradient optimiser.
 */
public class GradientOptimiserOptions implements Options {
    private static final Logger logger = Logger.getLogger(GradientOptimiserOptions.class.getName());

    /**
     * Type of the gradient optimisers available in CLIs.
     */
    public enum StochasticGradientOptimiserType {
        /** The AdaDelta optimiser. */
        ADADELTA,
        /** The AdaGrad optimiser. */
        ADAGRAD,
        /** The AdaGrad optimiser with regularised dual averaging. */
        ADAGRADRDA,
        /** The Adam optimiser. */
        ADAM,
        /** The Pegasos optimiser. */
        PEGASOS,
        /** The RMSProp optimiser. */
        RMSPROP,
        /** SGD with a constant learning rate. */
        CONSTANTSGD,
        /** SGD with a linearly decaying learning rate. */
        LINEARSGD,
        /** SGD with a learning rate that decays as the square root of the iteration count. */
        SQRTSGD
    }

    @Option(longName = "sgo-type", usage = "Selects the gradient optimiser. Defaults to ADAGRAD.")
    private StochasticGradientOptimiserType optimiserType = StochasticGradientOptimiserType.ADAGRAD;

    /** Learning rate for AdaGrad, AdaGradRDA, Adam and Pegasos. */
    @Option(longName = "sgo-learning-rate", usage = "Learning rate for AdaGrad, AdaGradRDA, Adam, Pegasos.")
    public double learningRate = 0.18;

    /** Epsilon for AdaDelta, AdaGrad, AdaGradRDA and Adam. */
    @Option(longName = "sgo-epsilon", usage = "Epsilon for AdaDelta, AdaGrad, AdaGradRDA, Adam.")
    public double epsilon = 0.066;

    /** Rho for RMSProp, AdaDelta and SGD with momentum. */
    @Option(longName = "sgo-rho", usage = "Rho for RMSProp, AdaDelta, SGD with Momentum.")
    public double rho = 0.95;

    /** Lambda for Pegasos. */
    @Option(longName = "sgo-lambda", usage = "Lambda for Pegasos.")
    public double lambda = 1e-2;

    /** Wrap the chosen optimiser in parameter averaging. */
    @Option(longName = "sgo-parameter-averaging", usage = "Use parameter averaging.")
    public boolean paramAve = false;

    /** Momentum type to use in SGD. */
    @Option(longName = "sgo-momentum", usage = "Use momentum in SGD.")
    public SGD.Momentum momentum = SGD.Momentum.NONE;

    /**
     * Gets the configured gradient optimiser.
     * @return The gradient optimiser.
     */
    public StochasticGradientOptimiser getOptimiser() {
        StochasticGradientOptimiser sgo;
        switch (optimiserType) {
            case ADADELTA:
                sgo = new AdaDelta(rho, epsilon);
                break;
            case ADAGRAD:
                sgo = new AdaGrad(learningRate, epsilon);
                break;
            case ADAGRADRDA:
                sgo = new AdaGradRDA(learningRate, epsilon);
                break;
            case ADAM:
                sgo = new Adam(learningRate, epsilon);
                break;
            case PEGASOS:
                sgo = new Pegasos(learningRate, lambda);
                break;
            case RMSPROP:
                sgo = new RMSProp(learningRate, rho);
                break;
            case CONSTANTSGD:
                sgo = SGD.getSimpleSGD(learningRate, rho, momentum);
                break;
            case LINEARSGD:
                sgo = SGD.getLinearDecaySGD(learningRate, rho, momentum);
                break;
            case SQRTSGD:
                sgo = SGD.getSqrtDecaySGD(learningRate, rho, momentum);
                break;
            default:
                throw new IllegalArgumentException("Unhandled StochasticGradientOptimiser type: " + optimiserType);
        }
        if (paramAve) {
            logger.info("Using parameter averaging");
            return new ParameterAveraging(sgo);
        } else {
            return sgo;
        }
    }

}
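
/*
 * A minimal usage sketch (illustrative, not part of the original class). It assumes the
 * standard OLCUT pattern of handing an Options instance to a ConfigurationManager, which
 * parses the command line and populates the @Option fields before getOptimiser() is called;
 * the surrounding main method and argument values are hypothetical.
 *
 *     import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
 *
 *     public static void main(String[] args) {
 *         GradientOptimiserOptions options = new GradientOptimiserOptions();
 *         // Parses arguments such as "--sgo-type ADAM --sgo-learning-rate 0.01" into the fields above.
 *         ConfigurationManager cm = new ConfigurationManager(args, options);
 *         StochasticGradientOptimiser optimiser = options.getOptimiser();
 *         // The optimiser can then be supplied to a gradient-based Tribuo trainer.
 *     }
 */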