/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.common.libsvm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * A trainer that will train using libsvm's Java implementation.
 * <p>
 * See:
 * <pre>
 * Chang CC, Lin CJ.
 * "LIBSVM: a library for Support Vector Machines"
 * ACM transactions on intelligent systems and technology (TIST), 2011.
 * </pre>
 * for the nu-svm algorithm:
 * <pre>
 * Schölkopf B, Smola A, Williamson R, Bartlett P L.
 * "New support vector algorithms"
 * Neural Computation, 2000, 1207-1245.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Cortes C, Vapnik V.
 * "Support-Vector Networks"
 * Machine Learning, 1995.
 * </pre>
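 * <p>
 * A rough usage sketch; {@code ExampleSVMTrainer}, {@code MyOutput}, {@code type} and
 * {@code dataset} are placeholders for a task-specific subclass, its {@link Output} type,
 * its {@link SVMType} and a training {@link Dataset}, and the {@link SVMParameters}
 * constructor shown (SVM type plus kernel) is an assumption for illustration:
 * <pre>{@code
 *     SVMParameters<MyOutput> params = new SVMParameters<>(type, KernelType.RBF);
 *     LibSVMTrainer<MyOutput> trainer = new ExampleSVMTrainer(params);
 *     LibSVMModel<MyOutput> model = trainer.train(dataset);
 * }</pre>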
 */
public abstract class LibSVMTrainer<T extends Output<T>> implements Trainer<T> {

    private static final Logger logger = Logger.getLogger(LibSVMTrainer.class.getName());

    /**
     * The SVM parameters suitable for use by LibSVM.
     */
    protected svm_parameter parameters;

    /**
     * The type of SVM algorithm.
     */
    @Config(mandatory=true,description="Type of SVM algorithm.")
    protected SVMType<T> svmType;

    @Config(description="Type of Kernel.")
    private KernelType kernelType = KernelType.LINEAR;

    @Config(description="Polynomial degree.")
    private int degree = 3;

    @Config(description="Width of the RBF kernel, or scalar on sigmoid kernel.")
    private double gamma = 0.0;

    @Config(description="Polynomial coefficient or shift in sigmoid kernel.")
    private double coef0 = 0.0;

    @Config(description="nu value in NU SVM.")
    private double nu = 0.5;

    @Config(description="Internal cache size in MB, most of the time should be left at default.")
    private double cache_size = 500;

    @Config(description="Cost parameter for incorrect predictions.")
    private double cost = 1.0; // aka svm_parameter.C

    @Config(description="Tolerance of the termination criterion.")
    private double eps = 1e-3;

    @Config(description="Epsilon in EPSILON_SVR.")
    private double p = 1e-3;

    @Config(description="Use the shrinking heuristics.")
    private boolean shrinking = true;

    @Config(description="Generate probability estimates.")
    private boolean probability = false;

    private int trainInvocationCounter = 0;

    /**
     * For olcut.
     */
    protected LibSVMTrainer() {}

    /**
     * Constructs a LibSVMTrainer from the parameters.
     * @param parameters The SVM parameters.
     */
    protected LibSVMTrainer(SVMParameters<T> parameters) {
        this.parameters = parameters.getParameters();
        // Unpack the parameters for the provenance system.
        this.svmType = parameters.getSvmType();
        this.kernelType = parameters.getKernelType();
        this.degree = this.parameters.degree;
        this.gamma = parameters.getGamma();
        this.coef0 = this.parameters.coef0;
        this.nu = this.parameters.nu;
        this.cache_size = this.parameters.cache_size;
        this.cost = this.parameters.C;
        this.eps = this.parameters.eps;
        this.p = this.parameters.p;
        this.shrinking = this.parameters.shrinking == 1;
        this.probability = this.parameters.probability == 1;
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        parameters = new svm_parameter();
        parameters.svm_type = svmType.getNativeType();
        parameters.kernel_type = kernelType.getNativeType();
        parameters.degree = degree;
        parameters.gamma = gamma;
        parameters.coef0 = coef0;
        parameters.nu = nu;
        parameters.cache_size = cache_size;
        parameters.C = cost;
        parameters.eps = eps;
        parameters.p = p;
        parameters.shrinking = shrinking ? 1 : 0;
        parameters.probability = probability ? 1 : 0;
    }

    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();

        buffer.append("LibSVMTrainer(");
        buffer.append("svm_params=");
        buffer.append(SVMParameters.svmParamsToString(parameters));
        buffer.append(")");

        return buffer.toString();
    }

    @Override
    public LibSVMModel<T> train(Dataset<T> examples) {
        return train(examples, Collections.emptyMap());
    }

    @Override
    public LibSVMModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo();

        TrainerProvenance trainerProvenance = getProvenance();
        ModelProvenance provenance = new ModelProvenance(LibSVMModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        trainInvocationCounter++;

        svm_parameter curParams = setupParameters(outputIDInfo);

        Pair<svm_node[][],double[][]> data = extractData(examples,outputIDInfo,featureIDMap);

        List<svm_model> models = trainModels(curParams,featureIDMap.size()+1,data.getA(),data.getB());

        return createModel(provenance,featureIDMap,outputIDInfo,models);
    }

    /**
     * Construct the appropriate subtype of LibSVMModel for the prediction task.
     * @param provenance The model provenance.
     * @param featureIDMap The feature id map.
     * @param outputIDInfo The output id info.
     * @param models The svm models.
     * @return An implementation of LibSVMModel.
     */
    protected abstract LibSVMModel<T> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<svm_model> models);

    /**
     * Train all the LibSVM instances necessary for this dataset.
     * @param curParams The LibSVM parameters.
     * @param numFeatures The number of features in this dataset.
     * @param features The features themselves.
     * @param outputs The outputs.
     * @return A list of LibSVM models.
     */
    protected abstract List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs);

    /**
     * Extracts the features and {@link Output}s in LibSVM's format.
     * @param data The input data.
     * @param outputInfo The output info.
     * @param featureMap The feature info.
     * @return The features and outputs.
     */
    protected abstract Pair<svm_node[][],double[][]> extractData(Dataset<T> data, ImmutableOutputInfo<T> outputInfo, ImmutableFeatureMap featureMap);

    /**
     * Constructs the svm_parameter. Most of the time this is a no-op, but
     * classification overrides it to incorporate label weights if they exist.
     * @param info The output info.
     * @return The svm_parameter to use for training.
     */
    protected svm_parameter setupParameters(ImmutableOutputInfo<T> info) {
        return SVMParameters.copyParameters(parameters);
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    /**
     * Convert the example into an array of svm_node which represents a sparse feature vector.
     * <p>
     * If there are collisions in the feature ids then the values are summed.
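     * <p>
     * A typical call, reusing a buffer across conversions; {@code example} and
     * {@code featureIDMap} are assumed to come from the surrounding training loop:
     * <pre>{@code
     *     List<svm_node> buffer = new ArrayList<>();
     *     svm_node[] nodes = LibSVMTrainer.exampleToNodes(example, featureIDMap, buffer);
     * }</pre>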
     * @param example The example to convert.
     * @param featureIDMap The feature id map which holds the indices.
     * @param features A buffer to use.
     * @param <T> The type of the output.
     * @return A sparse feature vector.
     */
    public static <T extends Output<T>> svm_node[] exampleToNodes(Example<T> example, ImmutableFeatureMap featureIDMap, List<svm_node> features) {
        if (features == null) {
            features = new ArrayList<>();
        }
        features.clear();
        int prevIdx = -1;
        for (Feature f : example) {
            int id = featureIDMap.getID(f.getName());
            double value = f.getValue();
            if (id > prevIdx){
                prevIdx = id;
                svm_node n = new svm_node();
                n.index = id;
                n.value = value;
                features.add(n);
            } else if (id > -1) {
                //
                // Collision, deal with it.
                int collisionIdx = Util.binarySearch(features,id,(svm_node n) -> n.index);
                if (collisionIdx < 0) {
                    //
                    // Collision but not present in features,
                    // insert a new node at the right place to keep the list sorted by index.
                    collisionIdx = - (collisionIdx + 1);
                    svm_node n = new svm_node();
                    n.index = id;
                    n.value = value;
                    features.add(collisionIdx,n);
                } else {
                    //
                    // Collision present in features,
                    // add the values.
                    svm_node n = features.get(collisionIdx);
                    n.value += value;
                }
            }
        }
        return features.toArray(new svm_node[0]);
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}