/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.sgd.linear;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;

/**
 * The inference time version of a linear model trained using SGD.
 * The output dimensions are independent, unless they are tied together by the
 * optimiser.
 * <p>
 * See:
 * <pre>
 * Bottou L.
 * "Large-Scale Machine Learning with Stochastic Gradient Descent"
 * Proceedings of COMPSTAT, 2010.
 * </pre>
 */
public class LinearSGDModel extends Model<Regressor> {
    private static final long serialVersionUID = 3L;

    // Name of each output dimension, index-aligned with the rows of {@code weights}.
    private final String[] dimensionNames;
    // Weight matrix: one row per output dimension, one column per feature plus a
    // trailing bias column (hence getDimension2Size() == numFeatures + 1).
    private final DenseMatrix weights;

    /**
     * Constructs a linear regression model trained via SGD, extracting the
     * weight matrix from the supplied parameters.
     * @param name The model name.
     * @param dimensionNames The names of the output dimensions.
     * @param description The model provenance.
     * @param featureIDMap The feature id map.
     * @param labelIDMap The output id map.
     * @param parameters The trained parameters (i.e., the weight matrix).
     */
    LinearSGDModel(String name, String[] dimensionNames, ModelProvenance description,
                   ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap,
                   LinearParameters parameters) {
        // This model does not generate probabilistic outputs, hence 'false'.
        super(name, description, featureIDMap, labelIDMap, false);
        this.weights = parameters.getWeightMatrix();
        this.dimensionNames = dimensionNames;
    }

    /**
     * Private constructor used by {@link #copy} which accepts the weight
     * matrix directly rather than unwrapping a {@link LinearParameters}.
     * @param name The model name.
     * @param dimensionNames The names of the output dimensions.
     * @param description The model provenance.
     * @param featureIDMap The feature id map.
     * @param labelIDMap The output id map.
     * @param weights The weight matrix.
     */
    private LinearSGDModel(String name, String[] dimensionNames, ModelProvenance description,
                           ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap,
                           DenseMatrix weights) {
        super(name, description, featureIDMap, labelIDMap, false);
        this.weights = weights;
        this.dimensionNames = dimensionNames;
    }

    /**
     * Predicts the regressed values for the supplied example.
     * @param example The example to predict.
     * @return The prediction, one value per output dimension.
     * @throws IllegalArgumentException If the example contains no features known
     *         to this model.
     */
    @Override
    public Prediction<Regressor> predict(Example<Regressor> example) {
        // The third argument adds the implicit bias feature, so the vector
        // always has at least one active element.
        SparseVector features = SparseVector.createSparseVector(example, featureIDMap, true);
        // Only the bias is active => none of the example's features are known to this model.
        if (features.numActiveElements() == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        DenseVector prediction = weights.leftMultiply(features);
        return new Prediction<>(new Regressor(dimensionNames, prediction.toArray()), features.numActiveElements(), example);
    }

    /**
     * Gets the top {@code n} features per output dimension, ranked by the
     * absolute value of their weight. The bias term is included under the
     * {@code BIAS_FEATURE} name.
     * @param n The number of features to return per dimension; a negative value
     *          returns all features (plus the bias).
     * @return A map from dimension name to its highest-|weight| features,
     *         sorted descending by absolute weight.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        // Negative n means "all features"; +1 accounts for the bias term.
        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;

        // Min-heap on |weight| so the smallest of the current top-N is at the head.
        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));

        //
        // Use a priority queue to find the top N features.
        int numClasses = weights.getDimension1Size();
        int numFeatures = weights.getDimension2Size() - 1; //Removing the bias feature.
        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        for (int i = 0; i < numClasses; i++) {
            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);

            for (int j = 0; j < numFeatures; j++) {
                Pair<String, Double> curr = new Pair<>(featureIDMap.get(j).getName(), weights.get(i, j));

                if (q.size() < maxFeatures) {
                    q.offer(curr);
                } else if (comparator.compare(curr, q.peek()) > 0) {
                    // New candidate beats the weakest of the current top-N; swap it in.
                    q.poll();
                    q.offer(curr);
                }
            }
            // The bias lives in the final column of the weight matrix.
            Pair<String, Double> curr = new Pair<>(BIAS_FEATURE, weights.get(i, numFeatures));

            if (q.size() < maxFeatures) {
                q.offer(curr);
            } else if (comparator.compare(curr, q.peek()) > 0) {
                q.poll();
                q.offer(curr);
            }
            ArrayList<Pair<String, Double>> b = new ArrayList<>();
            while (q.size() > 0) {
                b.add(q.poll());
            }

            // Heap drains smallest-first; reverse to get descending |weight| order.
            Collections.reverse(b);
            map.put(dimensionNames[i], b);
        }
        return map;
    }

    /**
     * Generates an excuse for this prediction: for each output dimension, the
     * per-feature contribution (weight times feature value) of every feature
     * present in the example, plus the bias, sorted descending by score.
     * @param example The example to excuse.
     * @return An excuse containing the feature contributions per dimension.
     */
    @Override
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        Prediction<Regressor> prediction = predict(example);
        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
        int numOutputs = weights.getDimension1Size();
        int numFeatures = weights.getDimension2Size() - 1; //Remove bias feature

        for (int i = 0; i < numOutputs; i++) {
            List<Pair<String, Double>> classScores = new ArrayList<>();
            for (Feature f : example) {
                int id = featureIDMap.getID(f.getName());
                // Features unknown to the model (id == -1) are skipped.
                if (id > -1) {
                    double score = weights.get(i, id) * f.getValue();
                    classScores.add(new Pair<>(f.getName(), score));
                }
            }
            // Bias contribution is its raw weight (implicit feature value of 1).
            classScores.add(new Pair<>(BIAS_FEATURE, weights.get(i, numFeatures)));
            // Sort descending by signed score (not absolute value).
            classScores.sort((Pair<String, Double> o1, Pair<String, Double> o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(dimensionNames[i], classScores);
        }

        return Optional.of(new Excuse<>(example, prediction, weightMap));
    }

    /**
     * Copies this model, deep-copying the dimension names and weight matrix.
     * @param newName The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return A deep copy of this model.
     */
    @Override
    protected LinearSGDModel copy(String newName, ModelProvenance newProvenance) {
        return new LinearSGDModel(newName, Arrays.copyOf(dimensionNames, dimensionNames.length), newProvenance, featureIDMap, outputIDInfo, getWeightsCopy());
    }

    /**
     * Returns a copy of the weight matrix (rows are output dimensions, columns
     * are features plus a trailing bias column).
     * @return A copy of the weights.
     */
    public DenseMatrix getWeightsCopy() {
        return weights.copy();
    }
}