001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.libsvm; 018 019import org.tribuo.Output; 020import libsvm.svm_parameter; 021 022import java.io.Serializable; 023import java.util.Arrays; 024import java.util.logging.Logger; 025 026/** 027 * A container for SVM parameters and the kernel. 028 */ 029public class SVMParameters<T extends Output<T>> implements Serializable { 030 private static final long serialVersionUID = 1L; 031 032 private static final Logger logger = Logger.getLogger(SVMParameters.class.getName()); 033 034 protected final SVMType<T> svmType; 035 036 protected final KernelType kernelType; 037 038 protected final svm_parameter parameters = new svm_parameter(); 039 040 public SVMParameters(SVMType<T> svmType, KernelType kernelType) { 041 this.svmType = svmType; 042 this.kernelType = kernelType; 043 parameters.svm_type = svmType.getNativeType(); 044 parameters.kernel_type = kernelType.getNativeType(); 045 // 046 // These are defaults, which are only compatible with SVM type 047 // C_SVC and kernel type RBF 048 parameters.degree = 3; 049 parameters.gamma = 0; // 1/num_features 050 parameters.coef0 = 0; 051 parameters.nu = 0.5; 052 parameters.cache_size = 500; 053 parameters.C = 1; 054 parameters.eps = 1e-3; 055 parameters.p = 0.1; 056 parameters.shrinking = 1; 057 parameters.probability = 0; 058 parameters.nr_weight = 0; 059 parameters.weight_label = new int[0]; 060 parameters.weight = new double[0]; 061 } 062 063 public SVMType<T> getSvmType() { 064 return svmType; 065 } 066 067 public KernelType getKernelType() { 068 return kernelType; 069 } 070 071 public svm_parameter getParameters() { 072 return parameters; 073 } 074 075 @Override 076 public String toString() { 077 return svmParamsToString(parameters); 078 } 079 080 /** 081 * Makes the model that is built provide probability estimates. 082 */ 083 public void setProbability() { 084 parameters.probability = 1; 085 } 086 087 public void setCost(double c) { 088 if(svmType.isNu() || !svmType.isClassification()) { 089 logger.warning(String.format("Setting cost %f for non-C_SVC model", c)); 090 } 091 parameters.C = c; 092 } 093 094 public void setNu(double nu) { 095 if(!svmType.isNu()) { 096 logger.warning(String.format("Setting nu %f for non-NU_SVM model", nu)); 097 } 098 parameters.nu = nu; 099 } 100 101 public void setCoeff(double coeff) { 102 parameters.coef0 = coeff; 103 } 104 105 public void setEpsilon(double epsilon) { 106 parameters.p = epsilon; 107 } 108 109 public void setDegree(int degree) { 110 parameters.degree = degree; 111 } 112 113 public void setGamma(double gamma) { 114 parameters.gamma = gamma; 115 } 116 117 public double getGamma() { 118 return parameters.gamma; 119 } 120 121 public void setCacheSize(double cacheMB) { 122 if(cacheMB <= 0) { 123 throw new IllegalArgumentException("Cache must be larger than 0MB"); 124 } 125 parameters.cache_size = cacheMB; 126 } 127 128 /** 129 * Deep copy of the svm_parameters including the arrays. 130 * @param input The parameters to copy. 131 * @return A copy of the svm_parameters. 132 */ 133 public static svm_parameter copyParameters(svm_parameter input) { 134 svm_parameter copy = new svm_parameter(); 135 copy.svm_type = input.svm_type; 136 copy.kernel_type = input.kernel_type; 137 copy.degree = input.degree; 138 copy.gamma = input.gamma; 139 copy.coef0 = input.coef0; 140 copy.cache_size = input.cache_size; 141 copy.eps = input.eps; 142 copy.C = input.C; 143 copy.nr_weight = input.nr_weight; 144 copy.nu = input.nu; 145 copy.p = input.p; 146 copy.shrinking = input.shrinking; 147 copy.probability = input.probability; 148 copy.weight_label = input.weight_label != null ? Arrays.copyOf(input.weight_label,input.weight_label.length) : null; 149 copy.weight = input.weight != null ? Arrays.copyOf(input.weight,input.weight.length) : null; 150 return copy; 151 } 152 153 /** 154 * A sensible toString for svm_parameter. 155 * @param param The parameters. 156 * @return A String describing the parameters. 157 */ 158 public static String svmParamsToString(svm_parameter param) { 159 StringBuilder sb = new StringBuilder(); 160 sb.append("svm_parameter(svm_type="); 161 sb.append(param.svm_type); 162 sb.append(", kernel_type="); 163 sb.append(param.kernel_type); 164 sb.append(", degree="); 165 sb.append(param.degree); 166 sb.append(", gamma="); 167 sb.append(param.gamma); 168 sb.append(", coef0="); 169 sb.append(param.coef0); 170 sb.append(", cache_size="); 171 sb.append(param.coef0); 172 sb.append(", eps="); 173 sb.append(param.eps); 174 sb.append(", C="); 175 sb.append(param.C); 176 sb.append(", nr_weight="); 177 sb.append(param.nr_weight); 178 if (param.weight_label != null) { 179 sb.append(", weight_label="); 180 sb.append(Arrays.toString(param.weight_label)); 181 } 182 if (param.weight != null) { 183 sb.append(", weight="); 184 sb.append(Arrays.toString(param.weight)); 185 } 186 sb.append(", nu="); 187 sb.append(param.nu); 188 sb.append(", p="); 189 sb.append(param.p); 190 sb.append(", shrinking="); 191 sb.append(param.shrinking); 192 sb.append(", probability="); 193 sb.append(param.probability); 194 sb.append(')'); 195 return sb.toString(); 196 } 197}