001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sgd.kernel;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Prediction;
026import org.tribuo.classification.Label;
027import org.tribuo.math.kernel.Kernel;
028import org.tribuo.math.la.DenseMatrix;
029import org.tribuo.math.la.DenseVector;
030import org.tribuo.math.la.SparseVector;
031import org.tribuo.provenance.ModelProvenance;
032
033import java.util.Collections;
034import java.util.LinkedHashMap;
035import java.util.List;
036import java.util.Map;
037import java.util.Optional;
038
039/**
040 * The inference time version of a kernel model trained using Pegasos.
041 * <p>
042 * See:
043 * <pre>
044 * Shalev-Shwartz S, Singer Y, Srebro N, Cotter A
045 * "Pegasos: Primal Estimated Sub-Gradient Solver for SVM"
046 * Mathematical Programming, 2011.
047 * </pre>
048 */
049public class KernelSVMModel extends Model<Label> {
050    private static final long serialVersionUID = 2L;
051
052    private final Kernel kernel;
053    private final SparseVector[] supportVectors;
054    private final DenseMatrix weights;
055
056    KernelSVMModel(String name, ModelProvenance description,
057                          ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap,
058                          Kernel kernel, SparseVector[] supportVectors, DenseMatrix weights) {
059        super(name, description, featureIDMap, labelIDMap, false);
060        this.kernel = kernel;
061        this.supportVectors = supportVectors;
062        this.weights = weights;
063    }
064
065    /**
066     * Returns the number of support vectors used.
067     * @return The number of support vectors.
068     */
069    public int getNumberOfSupportVectors() {
070        return supportVectors.length;
071    }
072
073    @Override
074    public Prediction<Label> predict(Example<Label> example) {
075        SparseVector features = SparseVector.createSparseVector(example,featureIDMap,true);
076        // Due to bias feature
077        if (features.numActiveElements() == 1) {
078            throw new IllegalArgumentException("No features found in Example " + example.toString());
079        }
080        double[] scores = new double[supportVectors.length];
081        for (int i = 0; i < scores.length; i++) {
082            scores[i] = kernel.similarity(features,supportVectors[i]);
083        }
084        DenseVector scoreVector = DenseVector.createDenseVector(scores);
085        DenseVector prediction = weights.leftMultiply(scoreVector);
086
087        double maxScore = Double.NEGATIVE_INFINITY;
088        Label maxLabel = null;
089        Map<String,Label> predMap = new LinkedHashMap<>();
090        for (int i = 0; i < prediction.size(); i++) {
091            String labelName = outputIDInfo.getOutput(i).getLabel();
092            Label label = new Label(labelName, prediction.get(i));
093            predMap.put(labelName, label);
094            if (label.getScore() > maxScore) {
095                maxScore = label.getScore();
096                maxLabel = label;
097            }
098        }
099        return new Prediction<>(maxLabel, predMap, features.numActiveElements(), example, generatesProbabilities);
100    }
101
102    @Override
103    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
104        return Collections.emptyMap();
105    }
106
107    @Override
108    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
109        return Optional.empty();
110    }
111
112    @Override
113    protected KernelSVMModel copy(String newName, ModelProvenance newProvenance) {
114        SparseVector[] vectorCopies = new SparseVector[supportVectors.length];
115        for (int i = 0; i < vectorCopies.length; i++) {
116            vectorCopies[i] = supportVectors[i].copy();
117        }
118        return new KernelSVMModel(newName,newProvenance,featureIDMap,outputIDInfo,kernel,vectorCopies,new DenseMatrix(weights));
119    }
120}