001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.libsvm;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.ImmutableFeatureMap;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.Prediction;
024import org.tribuo.classification.Label;
025import org.tribuo.common.libsvm.LibSVMModel;
026import org.tribuo.common.libsvm.LibSVMTrainer;
027import org.tribuo.provenance.ModelProvenance;
028import libsvm.svm;
029import libsvm.svm_model;
030import libsvm.svm_node;
031
032import java.util.Collections;
033import java.util.HashMap;
034import java.util.HashSet;
035import java.util.LinkedHashMap;
036import java.util.List;
037import java.util.Map;
038import java.util.Set;
039
040/**
041 * A classification model that uses an underlying LibSVM model to make the
042 * predictions.
043 * <p>
044 * See:
045 * <pre>
046 * Chang CC, Lin CJ.
047 * "LIBSVM: a library for Support Vector Machines"
048 * ACM transactions on intelligent systems and technology (TIST), 2011.
049 * </pre>
050 * for the nu-svc algorithm:
051 * <pre>
052 * Schölkopf B, Smola A, Williamson R, Bartlett P L.
053 * "New support vector algorithms"
054 * Neural Computation, 2000, 1207-1245.
055 * </pre>
056 * and for the original algorithm:
057 * <pre>
058 * Cortes C, Vapnik V.
059 * "Support-Vector Networks"
060 * Machine Learning, 1995.
061 * </pre>
062 */
063public class LibSVMClassificationModel extends LibSVMModel<Label> {
064    private static final long serialVersionUID = 3L;
065
066    /**
067     * This is used when the model hasn't seen as many outputs as the OutputInfo says are there.
068     * It stores the unseen labels to ensure the predict method has the right number of outputs.
069     * If there are no unobserved labels it's set to Collections.emptySet.
070     */
071    private final Set<Label> unobservedLabels;
072
073    LibSVMClassificationModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap, List<svm_model> models) {
074        super(name, description, featureIDMap, labelIDMap, models.get(0).param.probability == 1, models);
075        // This sets up the unobservedLabels variable.
076        int[] curLabels = models.get(0).label;
077        if (curLabels.length != labelIDMap.size()) {
078            Map<Integer,Label> tmp = new HashMap<>();
079            for (Pair<Integer,Label> p : labelIDMap) {
080                tmp.put(p.getA(),p.getB());
081            }
082            for (int i = 0; i < curLabels.length; i++) {
083                tmp.remove(i);
084            }
085            Set<Label> tmpSet = new HashSet<>(tmp.values().size());
086            for (Label l : tmp.values()) {
087                tmpSet.add(new Label(l.getLabel(),0.0));
088            }
089            this.unobservedLabels = Collections.unmodifiableSet(tmpSet);
090        } else {
091            this.unobservedLabels = Collections.emptySet();
092        }
093    }
094
095    public int getNumberOfSupportVectors() {
096        return models.get(0).SV.length;
097    }
098
099    @Override
100    public Prediction<Label> predict(Example<Label> example) {
101        svm_model model = models.get(0);
102        svm_node[] features = LibSVMTrainer.exampleToNodes(example, featureIDMap, null);
103        // Bias feature is always set
104        if (features.length == 0) {
105            throw new IllegalArgumentException("No features found in Example " + example.toString());
106        }
107        int[] labels = model.label;
108        double[] scores = new double[labels.length];
109        if (generatesProbabilities) {
110            svm.svm_predict_probability(model, features, scores);
111        } else {
112            //LibSVM returns a one vs one result, and unpacks it into a score vector by voting
113            double[] onevone = new double[labels.length * (labels.length - 1) / 2];
114            svm.svm_predict_values(model, features, onevone);
115            int counter = 0;
116            for (int i = 0; i < labels.length; i++) {
117                for (int j = i+1; j < labels.length; j++) {
118                    if (onevone[counter] > 0) {
119                        scores[i]++;
120                    } else {
121                        scores[j]++;
122                    }
123                    counter++;
124                }
125            }
126        }
127        double maxScore = Double.NEGATIVE_INFINITY;
128        Label maxLabel = null;
129        Map<String,Label> map = new LinkedHashMap<>();
130        for (int i = 0; i < scores.length; i++) {
131            String name = outputIDInfo.getOutput(labels[i]).getLabel();
132            Label label = new Label(name, scores[i]);
133            map.put(name,label);
134            if (label.getScore() > maxScore) {
135                maxScore = label.getScore();
136                maxLabel = label;
137            }
138        }
139        if (!unobservedLabels.isEmpty()) {
140            for (Label l : unobservedLabels) {
141                map.put(l.getLabel(),l);
142            }
143        }
144        return new Prediction<>(maxLabel, map, features.length, example, generatesProbabilities);
145    }
146
147    @Override
148    protected LibSVMClassificationModel copy(String newName, ModelProvenance newProvenance) {
149        return new LibSVMClassificationModel(newName,newProvenance,featureIDMap,outputIDInfo,Collections.singletonList(LibSVMModel.copyModel(models.get(0))));
150    }
151
152}