/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.liblinear;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.provenance.ModelProvenance;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.logging.Logger;

/**
 * A {@link Model} which wraps a LibLinear-java classification model.
 * <p>
 * It disables the LibLinear debug output as it's very chatty.
 * <p>
 * See:
 * <pre>
 * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
 * "LIBLINEAR: A library for Large Linear Classification"
 * Journal of Machine Learning Research, 2008.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Cortes C, Vapnik V.
 * "Support-Vector Networks"
 * Machine Learning, 1995.
 * </pre>
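 * <p>
 * A usage sketch (the dataset and variable names are illustrative; it assumes a
 * {@code LibLinearClassificationTrainer} constructed with default options):
 * <pre>
 * LibLinearClassificationTrainer trainer = new LibLinearClassificationTrainer();
 * Model&lt;Label&gt; model = trainer.train(trainData);
 * Prediction&lt;Label&gt; prediction = model.predict(testExample);
 * </pre>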
 */
public class LibLinearClassificationModel extends LibLinearModel<Label> {
    private static final long serialVersionUID = 3L;

    private static final Logger logger = Logger.getLogger(LibLinearClassificationModel.class.getName());

    /**
     * This is used when the model hasn't seen as many outputs as the OutputInfo says are there.
     * It stores the unseen labels to ensure the predict method produces a score for every label.
     * If there are no unobserved labels it's set to Collections.emptySet.
     */
    private final Set<Label> unobservedLabels;

    LibLinearClassificationModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap, List<de.bwaldvogel.liblinear.Model> models) {
        super(name, description, featureIDMap, labelIDMap, models.get(0).isProbabilityModel(), models);
        // This sets up the unobservedLabels variable.
        int[] curLabels = models.get(0).getLabels();
        if (curLabels.length != labelIDMap.size()) {
            Map<Integer,Label> tmp = new HashMap<>();
            for (Pair<Integer,Label> p : labelIDMap) {
                tmp.put(p.getA(),p.getB());
            }
            // Remove each label id LibLinear observed during training, leaving the unseen ones.
            for (int i = 0; i < curLabels.length; i++) {
                tmp.remove(curLabels[i]);
            }
            Set<Label> tmpSet = new HashSet<>(tmp.values().size());
            for (Label l : tmp.values()) {
                tmpSet.add(new Label(l.getLabel(),0.0));
            }
            this.unobservedLabels = Collections.unmodifiableSet(tmpSet);
        } else {
            this.unobservedLabels = Collections.emptySet();
        }
    }

    @Override
    public Prediction<Label> predict(Example<Label> example) {
        FeatureNode[] features = LibLinearTrainer.exampleToNodes(example, featureIDMap, null);
        // The bias feature is always set, so a length of 1 means the example
        // contained no features which overlap the feature map.
        if (features.length == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }

        de.bwaldvogel.liblinear.Model model = models.get(0);

        int[] labels = model.getLabels();
        double[] scores = new double[labels.length];

        if (model.isProbabilityModel()) {
            Linear.predictProbability(model, features, scores);
        } else {
            Linear.predictValues(model, features, scores);
            // Binary models produce a single decision value, so mirror it for the second class.
            if ((model.getNrClass() == 2) && (scores[1] == 0.0)) {
                scores[1] = -scores[0];
            }
        }

        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        Map<String,Label> map = new LinkedHashMap<>();
        for (int i = 0; i < scores.length; i++) {
            String name = outputIDInfo.getOutput(labels[i]).getLabel();
            Label label = new Label(name, scores[i]);
            map.put(name,label);
            if (label.getScore() > maxScore) {
                maxScore = label.getScore();
                maxLabel = label;
            }
        }
        if (!unobservedLabels.isEmpty()) {
            for (Label l : unobservedLabels) {
                map.put(l.getLabel(),l);
            }
        }
        return new Prediction<>(maxLabel, map, features.length-1, example, generatesProbabilities);
    }

    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        de.bwaldvogel.liblinear.Model model = models.get(0);
        int[] labels = model.getLabels();
        double[] featureWeights = model.getFeatureWeights();

        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));

        /*
         * LibLinear stores its weights as follows:
         * +------------------+------------------+------------+
         * | nr_class weights | nr_class weights | ...
         * | for 1st feature  | for 2nd feature  |
         * +------------------+------------------+------------+
         *
         * If bias >= 0, x becomes [x; bias]. The number of features is
         * increased by one, so w is a (nr_feature+1)*nr_class array. The
         * value of bias is stored in the variable bias.
         */
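        // A worked example of that layout (illustrative): with nr_class = 3, the
        // weight of feature j for the i-th entry of model.getLabels() sits at
        // featureWeights[j * nr_class + i]; the multiclass branch below indexes
        // the array in exactly this way.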

        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        int numClasses = model.getNrClass();
        int numFeatures = model.getNrFeature();
        if (numClasses == 2) {
            //
            // When numClasses == 2, LibLinear only stores one set of weights,
            // so the second label's weights are the negation of the first's.
            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);

            for (int i = 0; i < numFeatures; i++) {
                Pair<String, Double> cur = new Pair<>(featureIDMap.get(i).getName(), featureWeights[i]);
                if (q.size() < maxFeatures) {
                    q.offer(cur);
                } else if (comparator.compare(cur, q.peek()) > 0) {
                    q.poll();
                    q.offer(cur);
                }
            }
            List<Pair<String, Double>> list = new ArrayList<>();
            while (!q.isEmpty()) {
                list.add(q.poll());
            }
            Collections.reverse(list);
            map.put(outputIDInfo.getOutput(labels[0]).getLabel(), list);

            List<Pair<String, Double>> otherList = new ArrayList<>();
            for (Pair<String, Double> f : list) {
                Pair<String, Double> otherF = new Pair<>(f.getA(), -f.getB());
                otherList.add(otherF);
            }
            map.put(outputIDInfo.getOutput(labels[1]).getLabel(), otherList);
        } else {
            for (int i = 0; i < labels.length; i++) {
                PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);
                // Iterate over the non-bias features.
                for (int j = 0; j < numFeatures; j++) {
                    int index = (j * numClasses) + i;
                    Pair<String, Double> cur = new Pair<>(featureIDMap.get(j).getName(), featureWeights[index]);
                    if (q.size() < maxFeatures) {
                        q.offer(cur);
                    } else if (comparator.compare(cur, q.peek()) > 0) {
                        q.poll();
                        q.offer(cur);
                    }
                }
                List<Pair<String, Double>> list = new ArrayList<>();
                while (!q.isEmpty()) {
                    list.add(q.poll());
                }
                Collections.reverse(list);
                map.put(outputIDInfo.getOutput(labels[i]).getLabel(), list);
            }
        }
        return map;
    }

    @Override
    protected LibLinearClassificationModel copy(String newName, ModelProvenance newProvenance) {
        return new LibLinearClassificationModel(newName,newProvenance,featureIDMap,outputIDInfo,Collections.singletonList(copyModel(models.get(0))));
    }

    @Override
    protected double[][] getFeatureWeights() {
        double[][] featureWeights = new double[1][];
        featureWeights[0] = models.get(0).getFeatureWeights();
        return featureWeights;
    }

    /**
     * The call to model.getFeatureWeights in the public methods copies the
     * weight array, so this inner method accepts the weights as a parameter
     * to avoid making a fresh copy for every example in getExcuses.
     * <p>
     * If it becomes a problem then we could cache the feature weights in the
     * model.
     * @param e The example.
     * @param allFeatureWeights The feature weights.
     * @return An excuse for this example.
     */
    @Override
    protected Excuse<Label> innerGetExcuse(Example<Label> e, double[][] allFeatureWeights) {
        de.bwaldvogel.liblinear.Model model = models.get(0);
        double[] featureWeights = allFeatureWeights[0];
        int[] labels = model.getLabels();
        int numClasses = model.getNrClass();

        Prediction<Label> prediction = predict(e);
        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();

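        // Each feature's contribution to a class score is weight * feature value.
        // In the binary case only one weight vector is stored (see getTopFeatures
        // above), so the second label's contributions are the negations of the first's.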
        if (numClasses == 2) {
            List<Pair<String, Double>> posScores = new ArrayList<>();
            List<Pair<String, Double>> negScores = new ArrayList<>();
            for (Feature f : e) {
                int id = featureIDMap.getID(f.getName());
                if (id > -1) {
                    double score = featureWeights[id] * f.getValue();
                    posScores.add(new Pair<>(f.getName(), score));
                    negScores.add(new Pair<>(f.getName(), -score));
                }
            }
            posScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
            negScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(outputIDInfo.getOutput(labels[0]).getLabel(),posScores);
            weightMap.put(outputIDInfo.getOutput(labels[1]).getLabel(),negScores);
        } else {
            for (int i = 0; i < labels.length; i++) {
                List<Pair<String, Double>> classScores = new ArrayList<>();
                for (Feature f : e) {
                    int id = featureIDMap.getID(f.getName());
                    if (id > -1) {
                        double score = featureWeights[id * numClasses + i] * f.getValue();
                        classScores.add(new Pair<>(f.getName(), score));
                    }
                }
                classScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
                weightMap.put(outputIDInfo.getOutput(labels[i]).getLabel(), classScores);
            }
        }

        return new Excuse<>(e, prediction, weightMap);
    }
}