/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sgd.kernel;

import com.oracle.labs.mlrg.olcut.config.ArgumentException;
import com.oracle.labs.mlrg.olcut.config.Option;
import org.tribuo.Trainer;
import org.tribuo.classification.ClassificationOptions;
import org.tribuo.math.kernel.Kernel;
import org.tribuo.math.kernel.Linear;
import org.tribuo.math.kernel.Polynomial;
import org.tribuo.math.kernel.RBF;
import org.tribuo.math.kernel.Sigmoid;

import java.util.logging.Logger;

/**
 * Options for using the KernelSVMTrainer.
 * <p>
 * See:
 * <pre>
 * Shalev-Shwartz S, Singer Y, Srebro N, Cotter A
 * "Pegasos: Primal Estimated Sub-Gradient Solver for SVM"
 * Mathematical Programming, 2011.
 * </pre>
 */
public class KernelSVMOptions implements ClassificationOptions<KernelSVMTrainer> {
    private static final Logger logger = Logger.getLogger(KernelSVMOptions.class.getName());

    /**
     * The kernel types.
     */
    public enum KernelEnum {LINEAR, POLYNOMIAL, SIGMOID, RBF}

    /**
     * Intercept in kernel function. Used by the polynomial and sigmoid kernels.
     */
    @Option(longName = "kernel-intercept", usage = "Intercept in kernel function. Defaults to 1.0.")
    public double kernelIntercept = 1.0;
    /**
     * Degree in polynomial kernel function.
     */
    @Option(longName = "kernel-degree", usage = "Degree in polynomial kernel function. Defaults to 1.0.")
    public double kernelDegree = 1.0;
    /**
     * Gamma value in kernel function. Used by the polynomial, RBF and sigmoid kernels.
     */
    @Option(longName = "kernel-gamma", usage = "Gamma value in kernel function. Defaults to 1.0.")
    public double kernelGamma = 1.0;
    /**
     * Number of SGD epochs.
     */
    @Option(longName = "kernel-epochs", usage = "Number of SGD epochs. Defaults to 5.")
    public int kernelEpochs = 5;
    /**
     * Kernel function.
     */
    @Option(longName = "kernel-kernel", usage = "Kernel function. Defaults to LINEAR.")
    public KernelEnum kernelKernel = KernelEnum.LINEAR; //TODO should the default be KernelEnum.RBF?
    /**
     * Lambda value in gradient optimisation.
     */
    @Option(longName = "kernel-lambda", usage = "Lambda value in gradient optimisation. Defaults to 0.01.")
    public double kernelLambda = 0.01;
    /**
     * Log the objective after this many examples.
     */
    @Option(longName = "kernel-logging-interval", usage = "Log the objective after <int> examples. Defaults to 100.")
    public int kernelLoggingInterval = 100;
    /**
     * Random seed handed to the trainer. Public, like the other options, so that it
     * can be set when this object is configured programmatically.
     */
    @Option(longName = "kernel-seed", usage = "Sets the random seed for the Kernel SVM.")
    public long kernelSeed = Trainer.DEFAULT_SEED;

    /**
     * Builds a {@link KernelSVMTrainer} from the current option values.
     * <p>
     * The kernel object is chosen by {@link #kernelKernel}; only the parameters that
     * the chosen kernel's constructor accepts are passed to it.
     *
     * @return A trainer constructed from the selected kernel, lambda, epochs,
     *         logging interval and seed.
     * @throws ArgumentException If {@link #kernelKernel} is not one of the handled kernel types.
     */
    @Override
    public KernelSVMTrainer getTrainer() {
        logger.info("Configuring Kernel SVM Trainer");
        // Every non-throwing branch assigns the kernel exactly once, so no null placeholder is needed.
        final Kernel kernelObj;
        switch (kernelKernel) {
            case LINEAR:
                logger.info("Using a linear kernel");
                kernelObj = new Linear();
                break;
            case POLYNOMIAL:
                logger.info("Using a Polynomial kernel with gamma " + kernelGamma + ", intercept " + kernelIntercept + ", and degree " + kernelDegree);
                kernelObj = new Polynomial(kernelGamma, kernelIntercept, kernelDegree);
                break;
            case RBF:
                logger.info("Using an RBF kernel with gamma " + kernelGamma);
                kernelObj = new RBF(kernelGamma);
                break;
            case SIGMOID:
                logger.info("Using a tanh kernel with gamma " + kernelGamma + ", and intercept " + kernelIntercept);
                kernelObj = new Sigmoid(kernelGamma, kernelIntercept);
                break;
            default:
                logger.warning("Unknown kernel function " + kernelKernel);
                throw new ArgumentException("kernel-kernel", "Unknown kernel function " + kernelKernel);
        }
        logger.info(String.format("Set logging interval to %d", kernelLoggingInterval));
        return new KernelSVMTrainer(kernelObj, kernelLambda, kernelEpochs, kernelLoggingInterval, kernelSeed);
    }
}