001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.libsvm;
018
019import org.tribuo.Output;
020import libsvm.svm_parameter;
021
022import java.io.Serializable;
023import java.util.Arrays;
024import java.util.logging.Logger;
025
026/**
027 * A container for SVM parameters and the kernel.
028 */
029public class SVMParameters<T extends Output<T>> implements Serializable {
030    private static final long serialVersionUID = 1L;
031    
032    private static final Logger logger = Logger.getLogger(SVMParameters.class.getName());
033
034    protected final SVMType<T> svmType;
035    
036    protected final KernelType kernelType;
037    
038    protected final svm_parameter parameters = new svm_parameter();
039
040    public SVMParameters(SVMType<T> svmType, KernelType kernelType) {
041        this.svmType = svmType;
042        this.kernelType = kernelType;
043        parameters.svm_type = svmType.getNativeType();
044        parameters.kernel_type = kernelType.getNativeType();
045        //
046        // These are defaults, which are only compatible with SVM type
047        // C_SVC and kernel type RBF
048        parameters.degree = 3;
049        parameters.gamma = 0;   // 1/num_features
050        parameters.coef0 = 0;
051        parameters.nu = 0.5;
052        parameters.cache_size = 500;
053        parameters.C = 1;
054        parameters.eps = 1e-3;
055        parameters.p = 0.1;
056        parameters.shrinking = 1;
057        parameters.probability = 0;
058        parameters.nr_weight = 0;
059        parameters.weight_label = new int[0];
060        parameters.weight = new double[0];
061    }
062
063    public SVMType<T> getSvmType() {
064        return svmType;
065    }
066
067    public KernelType getKernelType() {
068        return kernelType;
069    }
070
071    public svm_parameter getParameters() {
072        return parameters;
073    }
074
075    @Override
076    public String toString() {
077        return svmParamsToString(parameters);
078    }
079    
080    /**
081     * Makes the model that is built provide probability estimates.
082     */
083    public void setProbability() {
084        parameters.probability = 1;
085    }
086    
087    public void setCost(double c) {
088        if(svmType.isNu() || !svmType.isClassification()) {
089            logger.warning(String.format("Setting cost %f for non-C_SVC model", c));
090        }
091        parameters.C = c;
092    }
093    
094    public void setNu(double nu) {
095        if(!svmType.isNu()) {
096            logger.warning(String.format("Setting nu %f for non-NU_SVM model", nu));
097        }
098        parameters.nu = nu;
099    }
100
101    public void setCoeff(double coeff) {
102        parameters.coef0 = coeff;
103    }
104
105    public void setEpsilon(double epsilon) {
106        parameters.p = epsilon;
107    }
108
109    public void setDegree(int degree) {
110        parameters.degree = degree;
111    }
112
113    public void setGamma(double gamma) {
114        parameters.gamma = gamma;
115    }
116    
117    public double getGamma() {
118        return parameters.gamma;
119    }
120    
121    public void setCacheSize(double cacheMB) {
122        if(cacheMB <= 0) {
123            throw new IllegalArgumentException("Cache must be larger than 0MB");
124        }
125        parameters.cache_size = cacheMB;
126    }
127
128    /**
129     * Deep copy of the svm_parameters including the arrays.
130     * @param input The parameters to copy.
131     * @return A copy of the svm_parameters.
132     */
133    public static svm_parameter copyParameters(svm_parameter input) {
134        svm_parameter copy = new svm_parameter();
135        copy.svm_type = input.svm_type;
136        copy.kernel_type = input.kernel_type;
137        copy.degree = input.degree;
138        copy.gamma = input.gamma;
139        copy.coef0 = input.coef0;
140        copy.cache_size = input.cache_size;
141        copy.eps = input.eps;
142        copy.C = input.C;
143        copy.nr_weight = input.nr_weight;
144        copy.nu = input.nu;
145        copy.p = input.p;
146        copy.shrinking = input.shrinking;
147        copy.probability = input.probability;
148        copy.weight_label = input.weight_label != null ? Arrays.copyOf(input.weight_label,input.weight_label.length) : null;
149        copy.weight = input.weight != null ? Arrays.copyOf(input.weight,input.weight.length) : null;
150        return copy;
151    }
152
153    /**
154     * A sensible toString for svm_parameter.
155     * @param param The parameters.
156     * @return A String describing the parameters.
157     */
158    public static String svmParamsToString(svm_parameter param) {
159        StringBuilder sb = new StringBuilder();
160        sb.append("svm_parameter(svm_type=");
161        sb.append(param.svm_type);
162        sb.append(", kernel_type=");
163        sb.append(param.kernel_type);
164        sb.append(", degree=");
165        sb.append(param.degree);
166        sb.append(", gamma=");
167        sb.append(param.gamma);
168        sb.append(", coef0=");
169        sb.append(param.coef0);
170        sb.append(", cache_size=");
171        sb.append(param.coef0);
172        sb.append(", eps=");
173        sb.append(param.eps);
174        sb.append(", C=");
175        sb.append(param.C);
176        sb.append(", nr_weight=");
177        sb.append(param.nr_weight);
178        if (param.weight_label != null) {
179            sb.append(", weight_label=");
180            sb.append(Arrays.toString(param.weight_label));
181        }
182        if (param.weight != null) {
183            sb.append(", weight=");
184            sb.append(Arrays.toString(param.weight));
185        }
186        sb.append(", nu=");
187        sb.append(param.nu);
188        sb.append(", p=");
189        sb.append(param.p);
190        sb.append(", shrinking=");
191        sb.append(param.shrinking);
192        sb.append(", probability=");
193        sb.append(param.probability);
194        sb.append(')');
195        return sb.toString();
196    }
197}