001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.libsvm; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.util.Pair; 021import org.tribuo.Dataset; 022import org.tribuo.Example; 023import org.tribuo.ImmutableFeatureMap; 024import org.tribuo.ImmutableOutputInfo; 025import org.tribuo.classification.Label; 026import org.tribuo.classification.WeightedLabels; 027import org.tribuo.common.libsvm.LibSVMModel; 028import org.tribuo.common.libsvm.LibSVMTrainer; 029import org.tribuo.common.libsvm.SVMParameters; 030import org.tribuo.provenance.ModelProvenance; 031import libsvm.svm; 032import libsvm.svm_model; 033import libsvm.svm_node; 034import libsvm.svm_parameter; 035import libsvm.svm_problem; 036 037import java.util.ArrayList; 038import java.util.Collections; 039import java.util.HashMap; 040import java.util.List; 041import java.util.Map; 042import java.util.logging.Logger; 043 044/** 045 * A trainer for classification models that uses LibSVM. 046 * <p> 047 * See: 048 * <pre> 049 * Chang CC, Lin CJ. 050 * "LIBSVM: a library for Support Vector Machines" 051 * ACM transactions on intelligent systems and technology (TIST), 2011. 052 * </pre> 053 * for the nu-svc algorithm: 054 * <pre> 055 * Schölkopf B, Smola A, Williamson R, Bartlett P L. 056 * "New support vector algorithms" 057 * Neural Computation, 2000, 1207-1245. 058 * </pre> 059 * and for the original algorithm: 060 * <pre> 061 * Cortes C, Vapnik V. 062 * "Support-Vector Networks" 063 * Machine Learning, 1995. 064 * </pre> 065 */ 066public class LibSVMClassificationTrainer extends LibSVMTrainer<Label> implements WeightedLabels { 067 private static final Logger logger = Logger.getLogger(LibSVMClassificationTrainer.class.getName()); 068 069 @Config(description="Use Label specific weights.") 070 private Map<String,Float> labelWeights = Collections.emptyMap(); 071 072 protected LibSVMClassificationTrainer() {} 073 074 public LibSVMClassificationTrainer(SVMParameters<Label> parameters) { 075 super(parameters); 076 } 077 078 /** 079 * Used by the OLCUT configuration system, and should not be called by external code. 080 */ 081 @Override 082 public void postConfig() { 083 super.postConfig(); 084 if (!svmType.isClassification()) { 085 throw new IllegalArgumentException("Supplied regression or anomaly detection parameters to a classification SVM."); 086 } 087 } 088 089 @Override 090 protected LibSVMModel<Label> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, List<svm_model> models) { 091 return new LibSVMClassificationModel("svm-classification-model", provenance, featureIDMap, outputIDInfo, models); 092 } 093 094 @Override 095 protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs) { 096 svm_problem problem = new svm_problem(); 097 problem.l = outputs[0].length; 098 problem.x = features; 099 problem.y = outputs[0]; 100 if (curParams.gamma == 0) { 101 curParams.gamma = 1.0 / numFeatures; 102 } 103 String checkString = svm.svm_check_parameter(problem, curParams); 104 if(checkString != null) { 105 throw new IllegalArgumentException("Error checking SVM parameters: " + checkString); 106 } 107 return Collections.singletonList(svm.svm_train(problem, curParams)); 108 } 109 110 @Override 111 protected Pair<svm_node[][], double[][]> extractData(Dataset<Label> data, ImmutableOutputInfo<Label> outputInfo, ImmutableFeatureMap featureMap) { 112 double[][] ys = new double[1][data.size()]; 113 svm_node[][] xs = new svm_node[data.size()][]; 114 List<svm_node> buffer = new ArrayList<>(); 115 int i = 0; 116 for (Example<Label> example : data) { 117 ys[0][i] = outputInfo.getID(example.getOutput()); 118 xs[i] = exampleToNodes(example, featureMap, buffer); 119 i++; 120 } 121 return new Pair<>(xs,ys); 122 } 123 124 @Override 125 protected svm_parameter setupParameters(ImmutableOutputInfo<Label> outputIDInfo) { 126 svm_parameter curParams = SVMParameters.copyParameters(parameters); 127 if (!labelWeights.isEmpty()) { 128 double[] weights = new double[outputIDInfo.size()]; 129 int[] indices = new int[outputIDInfo.size()]; 130 int i = 0; 131 for (Pair<Integer,Label> label : outputIDInfo) { 132 String labelName = label.getB().getLabel(); 133 Float weight = labelWeights.get(labelName); 134 indices[i] = label.getA(); 135 if (weight != null) { 136 weights[i] = weight; 137 } else { 138 weights[i] = 1.0f; 139 } 140 i++; 141 } 142 curParams.nr_weight = weights.length; 143 curParams.weight = weights; 144 curParams.weight_label = indices; 145 //logger.info("Weights = " + Arrays.toString(weights) + ", labels = " + Arrays.toString(indices) + ", outputIDInfo = " + outputIDInfo); 146 } 147 return curParams; 148 } 149 150 @Override 151 public void setLabelWeights(Map<Label,Float> weights) { 152 labelWeights = new HashMap<>(); 153 for (Map.Entry<Label,Float> e : weights.entrySet()) { 154 labelWeights.put(e.getKey().getLabel(),e.getValue()); 155 } 156 } 157}