/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.liblinear;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.provenance.ModelProvenance;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.logging.Logger;

/**
 * A {@link Model} which wraps a LibLinear-java classification model.
 * <p>
 * It disables the LibLinear debug output as it's very chatty.
 * <p>
 * See:
 * <pre>
 * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
 * "LIBLINEAR: A library for Large Linear Classification"
 * Journal of Machine Learning Research, 2008.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Cortes C, Vapnik V.
 * "Support-Vector Networks"
 * Machine Learning, 1995.
 * </pre>
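 * <p>
 * A minimal usage sketch (assuming the companion {@code LibLinearClassificationTrainer}
 * in this package; the dataset and example variables are illustrative):
 * <pre>
 *     LibLinearClassificationTrainer trainer = new LibLinearClassificationTrainer();
 *     Model&lt;Label&gt; model = trainer.train(trainingDataset);
 *     Prediction&lt;Label&gt; prediction = model.predict(testExample);
 * </pre>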
 */
public class LibLinearClassificationModel extends LibLinearModel<Label> {
    private static final long serialVersionUID = 3L;

    private static final Logger logger = Logger.getLogger(LibLinearClassificationModel.class.getName());

    /**
     * This is used when the model hasn't seen as many outputs as the OutputInfo says are there.
     * It stores the unseen labels to ensure the predict method has the right number of outputs.
     * If there are no unobserved labels it's set to Collections.emptySet.
     */
    private final Set<Label> unobservedLabels;

    LibLinearClassificationModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap, List<de.bwaldvogel.liblinear.Model> models) {
        super(name, description, featureIDMap, labelIDMap, models.get(0).isProbabilityModel(), models);
        // This sets up the unobservedLabels variable.
        int[] curLabels = models.get(0).getLabels();
        if (curLabels.length != labelIDMap.size()) {
            Map<Integer,Label> tmp = new HashMap<>();
            for (Pair<Integer,Label> p : labelIDMap) {
                tmp.put(p.getA(),p.getB());
            }
            for (int i = 0; i < curLabels.length; i++) {
                // Remove by label id, as the ids the model observed need not be contiguous.
                tmp.remove(curLabels[i]);
            }
            Set<Label> tmpSet = new HashSet<>(tmp.values().size());
            for (Label l : tmp.values()) {
                tmpSet.add(new Label(l.getLabel(),0.0));
            }
            this.unobservedLabels = Collections.unmodifiableSet(tmpSet);
        } else {
            this.unobservedLabels = Collections.emptySet();
        }
    }

    @Override
    public Prediction<Label> predict(Example<Label> example) {
        FeatureNode[] features = LibLinearTrainer.exampleToNodes(example, featureIDMap, null);
        // The bias feature is always set, so a length of 1 means the example
        // contains no features this model knows about.
        if (features.length == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }

        de.bwaldvogel.liblinear.Model model = models.get(0);

        int[] labels = model.getLabels();
        double[] scores = new double[labels.length];

        if (model.isProbabilityModel()) {
            Linear.predictProbability(model, features, scores);
        } else {
            Linear.predictValues(model, features, scores);
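            // For binary non-probabilistic models liblinear only fills in the first
            // decision value, so mirror it to give the second label a score.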
            if ((model.getNrClass() == 2) && (scores[1] == 0.0)) {
                scores[1] = -scores[0];
            }
        }

        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        Map<String,Label> map = new LinkedHashMap<>();
        for (int i = 0; i < scores.length; i++) {
            String name = outputIDInfo.getOutput(labels[i]).getLabel();
            Label label = new Label(name, scores[i]);
            map.put(name,label);
            if (label.getScore() > maxScore) {
                maxScore = label.getScore();
                maxLabel = label;
            }
        }
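        // Pad the prediction with zero scored labels for any classes the model
        // never observed during training.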
        if (!unobservedLabels.isEmpty()) {
            for (Label l : unobservedLabels) {
                map.put(l.getLabel(),l);
            }
        }
        return new Prediction<>(maxLabel, map, features.length-1, example, generatesProbabilities);
    }

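    /**
     * Returns the {@code n} most influential features for each label, ranked by
     * the absolute value of the feature weight; a negative {@code n} returns all
     * features. For example, {@code model.getTopFeatures(5)} gives the five
     * highest weighted features per class.
     */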
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        de.bwaldvogel.liblinear.Model model = models.get(0);
        int[] labels = model.getLabels();
        double[] featureWeights = model.getFeatureWeights();

        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));

        /*
         * Liblinear stores its weights as follows
         * +------------------+------------------+------------+
         * | nr_class weights | nr_class weights |  ...
         * | for 1st feature  | for 2nd feature  |
         * +------------------+------------------+------------+
         *
         * If bias >= 0, x becomes [x; bias]. The number of features is
         * increased by one, so w is a (nr_feature+1)*nr_class array. The
         * value of bias is stored in the variable bias.
         */
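        // In the multiclass case below, the weight for feature j and class i
        // therefore lives at featureWeights[(j * numClasses) + i].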

        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        int numClasses = model.getNrClass();
        int numFeatures = model.getNrFeature();
        if (numClasses == 2) {
            //
            // When numClasses == 2, liblinear only stores one set of weights.
            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);

            for (int i = 0; i < numFeatures; i++) {
                Pair<String, Double> cur = new Pair<>(featureIDMap.get(i).getName(), featureWeights[i]);
                if (q.size() < maxFeatures) {
                    q.offer(cur);
                } else if (comparator.compare(cur, q.peek()) > 0) {
                    q.poll();
                    q.offer(cur);
                }
            }
            List<Pair<String, Double>> list = new ArrayList<>();
            while (q.size() > 0) {
                list.add(q.poll());
            }
            Collections.reverse(list);
            map.put(outputIDInfo.getOutput(labels[0]).getLabel(), list);

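            // The second label's list is the negation of the first, since liblinear
            // stores a single weight vector for binary problems.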
            List<Pair<String, Double>> otherList = new ArrayList<>();
            for (Pair<String, Double> f : list) {
                Pair<String, Double> otherF = new Pair<>(f.getA(), -f.getB());
                otherList.add(otherF);
            }
            map.put(outputIDInfo.getOutput(labels[1]).getLabel(), otherList);
        } else {
            for (int i = 0; i < labels.length; i++) {
                PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);
                // Iterate over the non-bias features.
                for (int j = 0; j < numFeatures; j++) {
                    int index = (j * numClasses) + i;
                    Pair<String, Double> cur = new Pair<>(featureIDMap.get(j).getName(), featureWeights[index]);
                    if (q.size() < maxFeatures) {
                        q.offer(cur);
                    } else if (comparator.compare(cur, q.peek()) > 0) {
                        q.poll();
                        q.offer(cur);
                    }
                }
                List<Pair<String, Double>> list = new ArrayList<>();
                while (q.size() > 0) {
                    list.add(q.poll());
                }
                Collections.reverse(list);
                map.put(outputIDInfo.getOutput(labels[i]).getLabel(), list);
            }
        }
        return map;
    }

    @Override
    protected LibLinearClassificationModel copy(String newName, ModelProvenance newProvenance) {
        return new LibLinearClassificationModel(newName,newProvenance,featureIDMap,outputIDInfo,Collections.singletonList(copyModel(models.get(0))));
    }

    @Override
    protected double[][] getFeatureWeights() {
        double[][] featureWeights = new double[1][];
        featureWeights[0] = models.get(0).getFeatureWeights();
        return featureWeights;
    }

    /**
     * The call to model.getFeatureWeights in the public methods copies the
     * weights array, so this inner method exists to avoid repeating that copy
     * in getExcuses.
     * <p>
     * If that copy becomes a performance problem we could cache the feature
     * weights in the model.
     * @param e The example.
     * @param allFeatureWeights The feature weights.
     * @return An excuse for this example.
     */
    @Override
    protected Excuse<Label> innerGetExcuse(Example<Label> e, double[][] allFeatureWeights) {
        de.bwaldvogel.liblinear.Model model = models.get(0);
        double[] featureWeights = allFeatureWeights[0];
        int[] labels = model.getLabels();
        int numClasses = model.getNrClass();

        Prediction<Label> prediction = predict(e);
        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();

        if (numClasses == 2) {
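            // Binary models store a single weight vector, so the second label's
            // scores are the negation of the first's.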
            List<Pair<String, Double>> posScores = new ArrayList<>();
            List<Pair<String, Double>> negScores = new ArrayList<>();
            for (Feature f : e) {
                int id = featureIDMap.getID(f.getName());
                if (id > -1) {
                    double score = featureWeights[id] * f.getValue();
                    posScores.add(new Pair<>(f.getName(), score));
                    negScores.add(new Pair<>(f.getName(), -score));
                }
            }
            posScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
            negScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(outputIDInfo.getOutput(labels[0]).getLabel(),posScores);
            weightMap.put(outputIDInfo.getOutput(labels[1]).getLabel(),negScores);
        } else {
            for (int i = 0; i < labels.length; i++) {
                List<Pair<String, Double>> classScores = new ArrayList<>();
                for (Feature f : e) {
                    int id = featureIDMap.getID(f.getName());
                    if (id > -1) {
                        double score = featureWeights[id * numClasses + i] * f.getValue();
                        classScores.add(new Pair<>(f.getName(), score));
                    }
                }
                classScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
                weightMap.put(outputIDInfo.getOutput(labels[i]).getLabel(), classScores);
            }
        }

        return new Excuse<>(e, prediction, weightMap);
    }
}