/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sgd.linear;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;

/**
 * The inference time version of a linear model trained using SGD.
 * <p>
 * See:
 * <pre>
 * Bottou L.
 * "Large-Scale Machine Learning with Stochastic Gradient Descent"
 * Proceedings of COMPSTAT, 2010.
 * </pre>
 */
public class LinearSGDModel extends Model<Label> {
    private static final long serialVersionUID = 2L;

    /**
     * The weight matrix. Dimension 1 is indexed by output id; dimension 2 is indexed by
     * feature id, with the bias weight stored at the final index.
     */
    private final DenseMatrix weights;

    /**
     * The normalizer applied to the raw scores inside {@link #predict(Example)}.
     */
    private final VectorNormalizer normalizer;

    /**
     * Builds a model from trained linear parameters.
     * <p>
     * The weight matrix is obtained via {@code parameters.getWeightMatrix()}.
     *
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param labelIDMap The output domain.
     * @param parameters The trained parameters supplying the weight matrix.
     * @param normalizer The normalizer applied to the scores at prediction time.
     * @param generatesProbabilities Passed through to the {@link Model} constructor and to each {@link Prediction}.
     */
    LinearSGDModel(String name, ModelProvenance description,
                   ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap,
                   LinearParameters parameters, VectorNormalizer normalizer, boolean generatesProbabilities) {
        super(name, description, featureIDMap, labelIDMap, generatesProbabilities);
        this.weights = parameters.getWeightMatrix();
        this.normalizer = normalizer;
    }

    /**
     * Builds a model directly from a weight matrix. Used by {@link #copy(String, ModelProvenance)}.
     *
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param labelIDMap The output domain.
     * @param weights The weight matrix, stored as supplied.
     * @param normalizer The normalizer applied to the scores at prediction time.
     * @param generatesProbabilities Passed through to the {@link Model} constructor and to each {@link Prediction}.
     */
    private LinearSGDModel(String name, ModelProvenance description,
                           ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap,
                           DenseMatrix weights, VectorNormalizer normalizer, boolean generatesProbabilities) {
        super(name, description, featureIDMap, labelIDMap, generatesProbabilities);
        this.weights = weights;
        this.normalizer = normalizer;
    }

    /**
     * Scores the example against every output label and returns the highest scoring one.
     * <p>
     * The scores are the product of the weight matrix with the example's sparse feature
     * vector, passed through the normalizer. The returned prediction carries the score of
     * every label, in output id order.
     *
     * @param example The example to predict.
     * @return The prediction, whose output is the label with the largest score.
     * @throws IllegalArgumentException If the sparse vector built from the example holds
     *         a single active element (i.e. only the bias).
     */
    @Override
    public Prediction<Label> predict(Example<Label> example) {
        SparseVector features = SparseVector.createSparseVector(example,featureIDMap,true);
        // Due to bias feature
        if (features.numActiveElements() == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        DenseVector prediction = weights.leftMultiply(features);
        prediction.normalize(normalizer);

        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        // LinkedHashMap keeps the labels in output id order.
        Map<String,Label> predMap = new LinkedHashMap<>();
        for (int i = 0; i < prediction.size(); i++) {
            String labelName = outputIDInfo.getOutput(i).getLabel();
            Label label = new Label(labelName, prediction.get(i));
            predMap.put(labelName,label);
            if (label.getScore() > maxScore) {
                maxScore = label.getScore();
                maxLabel = label;
            }
        }
        // The bias element is not counted as a used feature.
        return new Prediction<>(maxLabel, predMap, features.numActiveElements()-1, example, generatesProbabilities);
    }

    /**
     * Returns, for each output label, the features (including the bias) with the largest
     * absolute weight, ordered from largest to smallest absolute weight.
     *
     * @param n The number of features to return per label. A negative value returns all
     *          of them; zero returns an empty list for each label.
     * @return A map from label name to the list of (feature name, weight) pairs.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;

        Comparator<Pair<String,Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));

        int numClasses = weights.getDimension1Size();
        int numFeatures = weights.getDimension2Size()-1; //Removing the bias feature.

        // At most numFeatures + 1 entries are ever offered, so there is no point reserving
        // more than that, and PriorityQueue rejects an initial capacity below 1.
        int queueCapacity = Math.max(1, Math.min(maxFeatures, numFeatures + 1));

        //
        // Use a priority queue to find the top N features.
        Map<String, List<Pair<String,Double>>> map = new HashMap<>();
        for (int i = 0; i < numClasses; i++) {
            // Min-heap on absolute weight: the head is the weakest of the retained features.
            PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(queueCapacity, comparator);

            for (int j = 0; j < numFeatures; j++) {
                offerBounded(q, new Pair<>(featureIDMap.get(j).getName(), weights.get(i,j)), maxFeatures, comparator);
            }
            // The bias weight lives in the final column and competes with the real features.
            offerBounded(q, new Pair<>(BIAS_FEATURE, weights.get(i,numFeatures)), maxFeatures, comparator);

            List<Pair<String,Double>> b = new ArrayList<>(q.size());
            while (q.size() > 0) {
                b.add(q.poll());
            }

            // Polling yields ascending absolute weight, callers expect descending.
            Collections.reverse(b);
            map.put(outputIDInfo.getOutput(i).getLabel(), b);
        }
        return map;
    }

    /**
     * Offers the candidate to a size-bounded queue, evicting the current head when the
     * queue is full and the candidate compares greater than it.
     *
     * @param queue The queue holding the retained elements.
     * @param candidate The element to consider.
     * @param maxSize The maximum number of elements to retain; zero retains nothing.
     * @param comparator The ordering used by the queue.
     */
    private static void offerBounded(PriorityQueue<Pair<String,Double>> queue, Pair<String,Double> candidate,
                                     int maxSize, Comparator<Pair<String,Double>> comparator) {
        if (queue.size() < maxSize) {
            queue.offer(candidate);
        } else if (!queue.isEmpty() && comparator.compare(candidate, queue.peek()) > 0) {
            queue.poll();
            queue.offer(candidate);
        }
    }

    /**
     * Explains a prediction by listing, for each output label, the contribution
     * (weight times feature value) of each feature of the example that has an id in the
     * feature map, plus the bias weight. Each list is sorted by descending signed score.
     *
     * @param example The example to explain.
     * @return An excuse wrapping the example, its prediction, and the per label scores.
     * @throws IllegalArgumentException If {@link #predict(Example)} rejects the example.
     */
    @Override
    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
        Prediction<Label> prediction = predict(example);
        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
        int numClasses = weights.getDimension1Size();
        int numFeatures = weights.getDimension2Size()-1;

        for (int i = 0; i < numClasses; i++) {
            List<Pair<String, Double>> classScores = new ArrayList<>();
            for (Feature f : example) {
                int id = featureIDMap.getID(f.getName());
                // Features without an id in the feature map are skipped.
                if (id > -1) {
                    double score = weights.get(i,id) * f.getValue();
                    classScores.add(new Pair<>(f.getName(), score));
                }
            }
            classScores.add(new Pair<>(Model.BIAS_FEATURE,weights.get(i,numFeatures)));
            classScores.sort((Pair<String, Double> o1, Pair<String, Double> o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(outputIDInfo.getOutput(i).getLabel(), classScores);
        }

        return Optional.of(new Excuse<>(example, prediction, weightMap));
    }

    /**
     * Creates a copy of this model with a new name and provenance.
     * <p>
     * The copy receives a new {@link DenseMatrix} constructed from this model's weights;
     * the feature map, output info and normalizer references are shared.
     *
     * @param newName The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return The copied model.
     */
    @Override
    protected LinearSGDModel copy(String newName, ModelProvenance newProvenance) {
        return new LinearSGDModel(newName,newProvenance,featureIDMap,outputIDInfo,new DenseMatrix(weights),normalizer,generatesProbabilities);
    }
}