001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sgd.kernel; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.Excuse; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Model; 025import org.tribuo.Prediction; 026import org.tribuo.classification.Label; 027import org.tribuo.math.kernel.Kernel; 028import org.tribuo.math.la.DenseMatrix; 029import org.tribuo.math.la.DenseVector; 030import org.tribuo.math.la.SparseVector; 031import org.tribuo.provenance.ModelProvenance; 032 033import java.util.Collections; 034import java.util.LinkedHashMap; 035import java.util.List; 036import java.util.Map; 037import java.util.Optional; 038 039/** 040 * The inference time version of a kernel model trained using Pegasos. 041 * <p> 042 * See: 043 * <pre> 044 * Shalev-Shwartz S, Singer Y, Srebro N, Cotter A 045 * "Pegasos: Primal Estimated Sub-Gradient Solver for SVM" 046 * Mathematical Programming, 2011. 047 * </pre> 048 */ 049public class KernelSVMModel extends Model<Label> { 050 private static final long serialVersionUID = 2L; 051 052 private final Kernel kernel; 053 private final SparseVector[] supportVectors; 054 private final DenseMatrix weights; 055 056 KernelSVMModel(String name, ModelProvenance description, 057 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap, 058 Kernel kernel, SparseVector[] supportVectors, DenseMatrix weights) { 059 super(name, description, featureIDMap, labelIDMap, false); 060 this.kernel = kernel; 061 this.supportVectors = supportVectors; 062 this.weights = weights; 063 } 064 065 /** 066 * Returns the number of support vectors used. 067 * @return The number of support vectors. 068 */ 069 public int getNumberOfSupportVectors() { 070 return supportVectors.length; 071 } 072 073 @Override 074 public Prediction<Label> predict(Example<Label> example) { 075 SparseVector features = SparseVector.createSparseVector(example,featureIDMap,true); 076 // Due to bias feature 077 if (features.numActiveElements() == 1) { 078 throw new IllegalArgumentException("No features found in Example " + example.toString()); 079 } 080 double[] scores = new double[supportVectors.length]; 081 for (int i = 0; i < scores.length; i++) { 082 scores[i] = kernel.similarity(features,supportVectors[i]); 083 } 084 DenseVector scoreVector = DenseVector.createDenseVector(scores); 085 DenseVector prediction = weights.leftMultiply(scoreVector); 086 087 double maxScore = Double.NEGATIVE_INFINITY; 088 Label maxLabel = null; 089 Map<String,Label> predMap = new LinkedHashMap<>(); 090 for (int i = 0; i < prediction.size(); i++) { 091 String labelName = outputIDInfo.getOutput(i).getLabel(); 092 Label label = new Label(labelName, prediction.get(i)); 093 predMap.put(labelName, label); 094 if (label.getScore() > maxScore) { 095 maxScore = label.getScore(); 096 maxLabel = label; 097 } 098 } 099 return new Prediction<>(maxLabel, predMap, features.numActiveElements(), example, generatesProbabilities); 100 } 101 102 @Override 103 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 104 return Collections.emptyMap(); 105 } 106 107 @Override 108 public Optional<Excuse<Label>> getExcuse(Example<Label> example) { 109 return Optional.empty(); 110 } 111 112 @Override 113 protected KernelSVMModel copy(String newName, ModelProvenance newProvenance) { 114 SparseVector[] vectorCopies = new SparseVector[supportVectors.length]; 115 for (int i = 0; i < vectorCopies.length; i++) { 116 vectorCopies[i] = supportVectors[i].copy(); 117 } 118 return new KernelSVMModel(newName,newProvenance,featureIDMap,outputIDInfo,kernel,vectorCopies,new DenseMatrix(weights)); 119 } 120}