/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.VariableInfo;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;
import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.logging.Logger;

/**
 * The inference time version of a sparse linear regression model.
 * <p>
 * The type of the model depends on the trainer used.
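 * <p>
 * An illustrative usage sketch (the trainer choice and the {@code trainData}
 * and {@code testExample} variables are hypothetical; any trainer in this
 * package produces this model type):
 * <pre>{@code
 * SparseTrainer<Regressor> trainer = new LARSLassoTrainer(10);
 * SparseLinearModel model = (SparseLinearModel) trainer.train(trainData);
 * Prediction<Regressor> prediction = model.predict(testExample);
 * Map<String, SparseVector> weights = model.getWeights();
 * }</pre>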
 */
public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel {
    private static final long serialVersionUID = 3L;
    private static final Logger logger = Logger.getLogger(SparseLinearModel.class.getName());

    private final SparseVector[] weights;
    private final DenseVector featureMeans;
    private final DenseVector featureVariance;
    private final boolean bias;
    private final double[] yMean;
    private final double[] yVariance;

    SparseLinearModel(String name, String[] dimensionNames, ModelProvenance description,
                      ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap,
                      SparseVector[] weights, DenseVector featureMeans, DenseVector featureVariance,
                      double[] yMean, double[] yVariance, boolean bias) {
        super(name, dimensionNames, description, featureIDMap, labelIDMap, generateActiveFeatures(dimensionNames, featureIDMap, weights));
        this.weights = weights;
        this.featureMeans = featureMeans;
        this.featureVariance = featureVariance;
        this.bias = bias;
        this.yVariance = yVariance;
        this.yMean = yMean;
    }

    private static Map<String, List<String>> generateActiveFeatures(String[] dimensionNames, ImmutableFeatureMap featureMap, SparseVector[] weightsArray) {
        Map<String, List<String>> map = new HashMap<>();

        for (int i = 0; i < dimensionNames.length; i++) {
            List<String> featureNames = new ArrayList<>();
            for (VectorTuple v : weightsArray[i]) {
                // An index equal to the feature map size denotes the implicit bias feature.
                if (v.index == featureMap.size()) {
                    featureNames.add(BIAS_FEATURE);
                } else {
                    VariableInfo info = featureMap.get(v.index);
                    featureNames.add(info.getName());
                }
            }
            map.put(dimensionNames[i], featureNames);
        }

        return map;
    }

    /**
     * Creates the feature vector. Includes a bias term if the model requires it.
     * <p>
     * The features are centered using the training means, then scaled by the
     * inverse of the training variances.
     * @param example The example to convert.
     * @return The feature vector.
     */
    @Override
    protected SparseVector createFeatures(Example<Regressor> example) {
        SparseVector features = SparseVector.createSparseVector(example, featureIDMap, bias);
        features.intersectAndAddInPlace(featureMeans, (a) -> -a);
        features.hadamardProductInPlace(featureVariance, (a) -> 1.0 / a);
        return features;
    }
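
    /**
     * Scores a single dimension by inverting the output rescaling applied at
     * training time: the dimension's weights are dotted with the rescaled
     * feature vector, then the result is multiplied by the output variance
     * and shifted by the output mean, i.e.
     * {@code y = (w . x') * yVariance + yMean}. If the dimension has no
     * active weights the dot product is replaced by 1.0.
     */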
    @Override
    protected DimensionTuple scoreDimension(int dimensionIdx, SparseVector features) {
        double prediction = weights[dimensionIdx].numActiveElements() > 0 ? weights[dimensionIdx].dot(features) : 1.0;
        prediction *= yVariance[dimensionIdx];
        prediction += yMean[dimensionIdx];
        return new DimensionTuple(dimensions[dimensionIdx], prediction);
    }

    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;

        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));

        //
        // Use a priority queue to find the top N features.
        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);

        for (int i = 0; i < dimensions.length; i++) {
            q.clear();
            for (VectorTuple v : weights[i]) {
                VariableInfo info = featureIDMap.get(v.index);
                // A null info means this weight is the bias term.
                String name = info == null ? BIAS_FEATURE : info.getName();
                Pair<String, Double> curr = new Pair<>(name, v.value);

                if (q.size() < maxFeatures) {
                    q.offer(curr);
                } else if (comparator.compare(curr, q.peek()) > 0) {
                    q.poll();
                    q.offer(curr);
                }
            }

            ArrayList<Pair<String, Double>> b = new ArrayList<>();
            while (q.size() > 0) {
                b.add(q.poll());
            }

            Collections.reverse(b);
            map.put(dimensions[i], b);
        }

        return map;
    }

    @Override
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        Prediction<Regressor> prediction = predict(example);
        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();

        SparseVector features = createFeatures(example);
        for (int i = 0; i < dimensions.length; i++) {
            List<Pair<String, Double>> classScores = new ArrayList<>();
            for (VectorTuple f : features) {
                double score = weights[i].get(f.index) * f.value;
                // As in getTopFeatures, a feature index outside the feature map is the bias term.
                VariableInfo info = featureIDMap.get(f.index);
                String name = info == null ? BIAS_FEATURE : info.getName();
                classScores.add(new Pair<>(name, score));
            }
            classScores.sort((Pair<String, Double> o1, Pair<String, Double> o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(dimensions[i], classScores);
        }

        return Optional.of(new Excuse<>(example, prediction, weightMap));
    }

    @Override
    protected Model<Regressor> copy(String newName, ModelProvenance newProvenance) {
        return new SparseLinearModel(newName, Arrays.copyOf(dimensions, dimensions.length),
                newProvenance, featureIDMap, outputIDInfo,
                copyWeights(),
                featureMeans.copy(), featureVariance.copy(),
                Arrays.copyOf(yMean, yMean.length), Arrays.copyOf(yVariance, yVariance.length), bias);
    }

    private SparseVector[] copyWeights() {
        SparseVector[] newWeights = new SparseVector[weights.length];

        for (int i = 0; i < weights.length; i++) {
            newWeights[i] = weights[i].copy();
        }

        return newWeights;
    }

    /**
     * Gets a copy of the model parameters.
     * @return A map from the dimension name to the model parameters.
     */
    public Map<String, SparseVector> getWeights() {
        SparseVector[] newWeights = copyWeights();
        Map<String, SparseVector> output = new HashMap<>();
        for (int i = 0; i < dimensions.length; i++) {
            output.put(dimensions[i], newWeights[i]);
        }
        return output;
    }
}