001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.libsvm;
018
019import com.oracle.labs.mlrg.olcut.config.Option;
020import org.tribuo.classification.ClassificationOptions;
021import org.tribuo.classification.Label;
022import org.tribuo.classification.libsvm.SVMClassificationType.SVMMode;
023import org.tribuo.common.libsvm.KernelType;
024import org.tribuo.common.libsvm.SVMParameters;
025
026/**
027 * CLI options for training a LibSVM classification model.
028 */
029public class LibSVMOptions implements ClassificationOptions<LibSVMClassificationTrainer> {
030
031    @Override
032    public String getOptionsDescription() {
033        return "Options for parameterising a LibSVM classification trainer.";
034    }
035
036    @Option(longName = "svm-coefficient", usage = "Intercept in kernel function. Defaults to 0.0.")
037    public double svmCoefficient = 0.0;  //TODO should this be 1.0?
038    @Option(longName = "svm-degree", usage = "Degree in polynomial kernel. Defaults to 3.")
039    public int svmDegree = 3;
040    @Option(longName = "svm-gamma", usage = "Gamma value in kernel function. Defaults to 0.0.")
041    public double svmGamma = 0.0;  //TODO should the default be 0.1
042    @Option(longName = "svm-kernel", usage = "Type of SVM kernel. Defaults to LINEAR.")
043    public KernelType svmKernel = KernelType.LINEAR;
044    @Option(longName = "svm-type", usage = "Type of SVM. Defaults to C_SVC.")
045    public SVMClassificationType.SVMMode svmType = SVMMode.C_SVC;
046
047    @Override
048    public LibSVMClassificationTrainer getTrainer() {
049        SVMParameters<Label> parameters = new SVMParameters<>(new SVMClassificationType(svmType), svmKernel);
050        parameters.setGamma(svmGamma);
051        parameters.setCoeff(svmCoefficient);
052        parameters.setDegree(svmDegree);
053        return new LibSVMClassificationTrainer(parameters);
054    }
055
056}