/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import org.tribuo.math.StochasticGradientOptimiser;

import java.util.logging.Logger;

/**
 * CLI options for configuring a gradient optimiser.
 */
public class GradientOptimiserOptions implements Options {
    private static final Logger logger = Logger.getLogger(GradientOptimiserOptions.class.getName());

    /**
     * Type of the gradient optimisers available in CLIs.
     */
    public enum StochasticGradientOptimiserType {
        /** The AdaDelta optimiser. */
        ADADELTA,
        /** The AdaGrad optimiser. */
        ADAGRAD,
        /** AdaGrad with regularised dual averaging. */
        ADAGRADRDA,
        /** The Adam optimiser. */
        ADAM,
        /** The Pegasos optimiser. */
        PEGASOS,
        /** The RMSProp optimiser. */
        RMSPROP,
        /** SGD with a constant learning rate. */
        CONSTANTSGD,
        /** SGD with a linearly decaying learning rate. */
        LINEARSGD,
        /** SGD with a learning rate decaying as the square root of the iteration count. */
        SQRTSGD
    }

    @Option(longName = "sgo-type", usage = "Selects the gradient optimiser. Defaults to ADAGRAD.")
    private StochasticGradientOptimiserType optimiserType = StochasticGradientOptimiserType.ADAGRAD;

    @Option(longName = "sgo-learning-rate", usage = "Learning rate for AdaGrad, AdaGradRDA, Adam, Pegasos.")
    public double learningRate = 0.18;

    @Option(longName = "sgo-epsilon", usage = "Epsilon for AdaDelta, AdaGrad, AdaGradRDA, Adam.")
    public double epsilon = 0.066;

    @Option(longName = "sgo-rho", usage = "Rho for RMSProp, AdaDelta, SGD with Momentum.")
    public double rho = 0.95;

    @Option(longName = "sgo-lambda", usage = "Lambda for Pegasos.")
    public double lambda = 1e-2;

    @Option(longName = "sgo-parameter-averaging", usage = "Use parameter averaging.")
    public boolean paramAve = false;

    @Option(longName = "sgo-momentum", usage = "Use momentum in SGD.")
    public SGD.Momentum momentum = SGD.Momentum.NONE;

    /**
     * Gets the configured gradient optimiser.
     * @return The gradient optimiser.
     */
    public StochasticGradientOptimiser getOptimiser() {
        StochasticGradientOptimiser sgo;
        switch (optimiserType) {
            case ADADELTA:
                sgo = new AdaDelta(rho, epsilon);
                break;
            case ADAGRAD:
                sgo = new AdaGrad(learningRate, epsilon);
                break;
            case ADAGRADRDA:
                sgo = new AdaGradRDA(learningRate, epsilon);
                break;
            case ADAM:
                sgo = new Adam(learningRate, epsilon);
                break;
            case PEGASOS:
                sgo = new Pegasos(learningRate, lambda);
                break;
            case RMSPROP:
                sgo = new RMSProp(learningRate, rho);
                break;
            case CONSTANTSGD:
                sgo = SGD.getSimpleSGD(learningRate, rho, momentum);
                break;
            case LINEARSGD:
                sgo = SGD.getLinearDecaySGD(learningRate, rho, momentum);
                break;
            case SQRTSGD:
                sgo = SGD.getSqrtDecaySGD(learningRate, rho, momentum);
                break;
            default:
                throw new IllegalArgumentException("Unhandled StochasticGradientOptimiser type: " + optimiserType);
        }
        if (paramAve) {
            logger.info("Using parameter averaging");
            return new ParameterAveraging(sgo);
        } else {
            return sgo;
        }
    }
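
    /*
     * Usage sketch (illustrative, not part of this class): in a Tribuo CLI these options are
     * typically populated from the command line via OLCUT's ConfigurationManager, after which
     * getOptimiser() builds the selected optimiser. The exact wiring below is an assumption and
     * varies between CLIs.
     *
     *   GradientOptimiserOptions sgoOptions = new GradientOptimiserOptions();
     *   ConfigurationManager cm = new ConfigurationManager(args, sgoOptions);
     *   StochasticGradientOptimiser optimiser = sgoOptions.getOptimiser();
     *
     * Here args is the String[] passed to main and ConfigurationManager is
     * com.oracle.labs.mlrg.olcut.config.ConfigurationManager.
     */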

}