001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sgd.kernel;
018
019import com.oracle.labs.mlrg.olcut.config.ArgumentException;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import org.tribuo.Trainer;
022import org.tribuo.classification.ClassificationOptions;
023import org.tribuo.math.kernel.Kernel;
024import org.tribuo.math.kernel.Linear;
025import org.tribuo.math.kernel.Polynomial;
026import org.tribuo.math.kernel.RBF;
027import org.tribuo.math.kernel.Sigmoid;
028
029import java.util.logging.Logger;
030
031/**
032 * Options for using the KernelSVMTrainer.
033 * <p>
034 * See:
035 * <pre>
036 * Shalev-Shwartz S, Singer Y, Srebro N, Cotter A
037 * "Pegasos: Primal Estimated Sub-Gradient Solver for SVM"
038 * Mathematical Programming, 2011.
039 * </pre>
040 */
041public class KernelSVMOptions implements ClassificationOptions<KernelSVMTrainer> {
042    private static final Logger logger = Logger.getLogger(KernelSVMOptions.class.getName());
043
044    /**
045     * The kernel types.
046     */
047    public enum KernelEnum {LINEAR, POLYNOMIAL, SIGMOID, RBF}
048
049    @Option(longName = "kernel-intercept", usage = "Intercept in kernel function. Defaults to 1.0.")
050    public double kernelIntercept = 1.0;
051    @Option(longName = "kernel-degree", usage = "Degree in polynomial kernel function. Defaults to 1.0.")
052    public double kernelDegree = 1.0;
053    @Option(longName = "kernel-gamma", usage = "Gamma value in kernel function. Defaults to 1.0.")
054    public double kernelGamma = 1.0;
055    @Option(longName = "kernel-epochs", usage = "Number of SGD epochs. Defaults to 5.")
056    public int kernelEpochs = 5;
057    @Option(longName = "kernel-kernel", usage = "Kernel function. Defaults to LINEAR.")
058    public KernelEnum kernelKernel = KernelEnum.LINEAR; //TODO should the default be KernelEnum.RBF?
059    @Option(longName = "kernel-lambda", usage = "Lambda value in gradient optimisation. Defaults to 0.01.")
060    public double kernelLambda = 0.01;
061    @Option(longName = "kernel-logging-interval", usage = "Log the objective after <int> examples. Defaults to 100.")
062    public int kernelLoggingInterval = 100;
063    @Option(longName = "kernel-seed", usage = "Sets the random seed for the Kernel SVM.")
064    private long kernelSeed = Trainer.DEFAULT_SEED;
065
066    @Override
067    public KernelSVMTrainer getTrainer() {
068        logger.info("Configuring Kernel SVM Trainer");
069        Kernel kernelObj = null;
070        switch (kernelKernel) {
071            case LINEAR:
072                logger.info("Using a linear kernel");
073                kernelObj = new Linear();
074                break;
075            case POLYNOMIAL:
076                logger.info("Using a Polynomial kernel with gamma " + kernelGamma + ", intercept " + kernelIntercept + ", and degree " + kernelDegree);
077                kernelObj = new Polynomial(kernelGamma, kernelIntercept, kernelDegree);
078                break;
079            case RBF:
080                logger.info("Using an RBF kernel with gamma " + kernelGamma);
081                kernelObj = new RBF(kernelGamma);
082                break;
083            case SIGMOID:
084                logger.info("Using a tanh kernel with gamma " + kernelGamma + ", and intercept " + kernelIntercept);
085                kernelObj = new Sigmoid(kernelGamma, kernelIntercept);
086                break;
087            default:
088                logger.warning("Unknown kernel function " + kernelKernel);
089                throw new ArgumentException("kernel-kernel", "Unknown kernel function " + kernelKernel);
090        }
091        logger.info(String.format("Set logging interval to %d", kernelLoggingInterval));
092        return new KernelSVMTrainer(kernelObj, kernelLambda, kernelEpochs, kernelLoggingInterval, kernelSeed);
093    }
094}