001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.libsvm; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.ImmutableFeatureMap; 022import org.tribuo.ImmutableOutputInfo; 023import org.tribuo.Prediction; 024import org.tribuo.classification.Label; 025import org.tribuo.common.libsvm.LibSVMModel; 026import org.tribuo.common.libsvm.LibSVMTrainer; 027import org.tribuo.provenance.ModelProvenance; 028import libsvm.svm; 029import libsvm.svm_model; 030import libsvm.svm_node; 031 032import java.util.Collections; 033import java.util.HashMap; 034import java.util.HashSet; 035import java.util.LinkedHashMap; 036import java.util.List; 037import java.util.Map; 038import java.util.Set; 039 040/** 041 * A classification model that uses an underlying LibSVM model to make the 042 * predictions. 043 * <p> 044 * See: 045 * <pre> 046 * Chang CC, Lin CJ. 047 * "LIBSVM: a library for Support Vector Machines" 048 * ACM transactions on intelligent systems and technology (TIST), 2011. 049 * </pre> 050 * for the nu-svc algorithm: 051 * <pre> 052 * Schölkopf B, Smola A, Williamson R, Bartlett P L. 053 * "New support vector algorithms" 054 * Neural Computation, 2000, 1207-1245. 055 * </pre> 056 * and for the original algorithm: 057 * <pre> 058 * Cortes C, Vapnik V. 059 * "Support-Vector Networks" 060 * Machine Learning, 1995. 061 * </pre> 062 */ 063public class LibSVMClassificationModel extends LibSVMModel<Label> { 064 private static final long serialVersionUID = 3L; 065 066 /** 067 * This is used when the model hasn't seen as many outputs as the OutputInfo says are there. 068 * It stores the unseen labels to ensure the predict method has the right number of outputs. 069 * If there are no unobserved labels it's set to Collections.emptySet. 070 */ 071 private final Set<Label> unobservedLabels; 072 073 LibSVMClassificationModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap, List<svm_model> models) { 074 super(name, description, featureIDMap, labelIDMap, models.get(0).param.probability == 1, models); 075 // This sets up the unobservedLabels variable. 076 int[] curLabels = models.get(0).label; 077 if (curLabels.length != labelIDMap.size()) { 078 Map<Integer,Label> tmp = new HashMap<>(); 079 for (Pair<Integer,Label> p : labelIDMap) { 080 tmp.put(p.getA(),p.getB()); 081 } 082 for (int i = 0; i < curLabels.length; i++) { 083 tmp.remove(i); 084 } 085 Set<Label> tmpSet = new HashSet<>(tmp.values().size()); 086 for (Label l : tmp.values()) { 087 tmpSet.add(new Label(l.getLabel(),0.0)); 088 } 089 this.unobservedLabels = Collections.unmodifiableSet(tmpSet); 090 } else { 091 this.unobservedLabels = Collections.emptySet(); 092 } 093 } 094 095 public int getNumberOfSupportVectors() { 096 return models.get(0).SV.length; 097 } 098 099 @Override 100 public Prediction<Label> predict(Example<Label> example) { 101 svm_model model = models.get(0); 102 svm_node[] features = LibSVMTrainer.exampleToNodes(example, featureIDMap, null); 103 // Bias feature is always set 104 if (features.length == 0) { 105 throw new IllegalArgumentException("No features found in Example " + example.toString()); 106 } 107 int[] labels = model.label; 108 double[] scores = new double[labels.length]; 109 if (generatesProbabilities) { 110 svm.svm_predict_probability(model, features, scores); 111 } else { 112 //LibSVM returns a one vs one result, and unpacks it into a score vector by voting 113 double[] onevone = new double[labels.length * (labels.length - 1) / 2]; 114 svm.svm_predict_values(model, features, onevone); 115 int counter = 0; 116 for (int i = 0; i < labels.length; i++) { 117 for (int j = i+1; j < labels.length; j++) { 118 if (onevone[counter] > 0) { 119 scores[i]++; 120 } else { 121 scores[j]++; 122 } 123 counter++; 124 } 125 } 126 } 127 double maxScore = Double.NEGATIVE_INFINITY; 128 Label maxLabel = null; 129 Map<String,Label> map = new LinkedHashMap<>(); 130 for (int i = 0; i < scores.length; i++) { 131 String name = outputIDInfo.getOutput(labels[i]).getLabel(); 132 Label label = new Label(name, scores[i]); 133 map.put(name,label); 134 if (label.getScore() > maxScore) { 135 maxScore = label.getScore(); 136 maxLabel = label; 137 } 138 } 139 if (!unobservedLabels.isEmpty()) { 140 for (Label l : unobservedLabels) { 141 map.put(l.getLabel(),l); 142 } 143 } 144 return new Prediction<>(maxLabel, map, features.length, example, generatesProbabilities); 145 } 146 147 @Override 148 protected LibSVMClassificationModel copy(String newName, ModelProvenance newProvenance) { 149 return new LibSVMClassificationModel(newName,newProvenance,featureIDMap,outputIDInfo,Collections.singletonList(LibSVMModel.copyModel(models.get(0)))); 150 } 151 152}