001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.slm;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Prediction;
026import org.tribuo.VariableInfo;
027import org.tribuo.math.la.DenseVector;
028import org.tribuo.math.la.SparseVector;
029import org.tribuo.math.la.VectorTuple;
030import org.tribuo.provenance.ModelProvenance;
031import org.tribuo.regression.Regressor;
032import org.tribuo.regression.Regressor.DimensionTuple;
033import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel;
034
035import java.util.ArrayList;
036import java.util.Arrays;
037import java.util.Collections;
038import java.util.Comparator;
039import java.util.HashMap;
040import java.util.List;
041import java.util.Map;
042import java.util.Optional;
043import java.util.PriorityQueue;
044import java.util.logging.Logger;
045
046/**
047 * The inference time version of a sparse linear regression model.
048 * <p>
049 * The type of the model depends on the trainer used.
050 */
051public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel {
052    private static final long serialVersionUID = 3L;
053    private static final Logger logger = Logger.getLogger(SparseLinearModel.class.getName());
054
055    private final SparseVector[] weights;
056    private final DenseVector featureMeans;
057    private final DenseVector featureVariance;
058    private final boolean bias;
059    private final double[] yMean;
060    private final double[] yVariance;
061
062    SparseLinearModel(String name, String[] dimensionNames, ModelProvenance description,
063                      ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap,
064                      SparseVector[] weights, DenseVector featureMeans, DenseVector featureVariance, double[] yMean, double[] yVariance, boolean bias) {
065        super(name, dimensionNames, description, featureIDMap, labelIDMap, generateActiveFeatures(dimensionNames,featureIDMap,weights));
066        this.weights = weights;
067        this.featureMeans = featureMeans;
068        this.featureVariance = featureVariance;
069        this.bias = bias;
070        this.yVariance = yVariance;
071        this.yMean = yMean;
072    }
073
074    private static Map<String,List<String>> generateActiveFeatures(String[] dimensionNames, ImmutableFeatureMap featureMap, SparseVector[] weightsArray) {
075        Map<String,List<String>> map = new HashMap<>();
076
077        for (int i = 0; i < dimensionNames.length; i++) {
078            List<String> featureNames = new ArrayList<>();
079            for (VectorTuple v : weightsArray[i]) {
080                if (v.index == featureMap.size()) {
081                    featureNames.add(BIAS_FEATURE);
082                } else {
083                    VariableInfo info = featureMap.get(v.index);
084                    featureNames.add(info.getName());
085                }
086            }
087            map.put(dimensionNames[i],featureNames);
088        }
089
090        return map;
091    }
092
093    /**
094     * Creates the feature vector. Includes a bias term if the model requires it.
095     * @param example The example to convert.
096     * @return The feature vector.
097     */
098    @Override
099    protected SparseVector createFeatures(Example<Regressor> example) {
100        SparseVector features = SparseVector.createSparseVector(example,featureIDMap,bias);
101        features.intersectAndAddInPlace(featureMeans,(a) -> -a);
102        features.hadamardProductInPlace(featureVariance,(a) -> 1.0/a);
103        return features;
104    }
105
106    @Override
107    protected DimensionTuple scoreDimension(int dimensionIdx, SparseVector features) {
108        double prediction = weights[dimensionIdx].numActiveElements() > 0 ? weights[dimensionIdx].dot(features) : 1.0;
109        prediction *= yVariance[dimensionIdx];
110        prediction += yMean[dimensionIdx];
111        return new DimensionTuple(dimensions[dimensionIdx],prediction);
112    }
113
114    @Override
115    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
116        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;
117
118        Comparator<Pair<String,Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
119
120        //
121        // Use a priority queue to find the top N features.
122        Map<String, List<Pair<String,Double>>> map = new HashMap<>();
123        PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures, comparator);
124
125        for (int i = 0; i < dimensions.length; i++) {
126            q.clear();
127            for (VectorTuple v : weights[i]) {
128                VariableInfo info = featureIDMap.get(v.index);
129                String name = info == null ? BIAS_FEATURE : info.getName();
130                Pair<String, Double> curr = new Pair<>(name, v.value);
131
132                if (q.size() < maxFeatures) {
133                    q.offer(curr);
134                } else if (comparator.compare(curr, q.peek()) > 0) {
135                    q.poll();
136                    q.offer(curr);
137                }
138            }
139
140            ArrayList<Pair<String, Double>> b = new ArrayList<>();
141            while (q.size() > 0) {
142                b.add(q.poll());
143            }
144
145            Collections.reverse(b);
146            map.put(dimensions[i], b);
147        }
148
149        return map;
150    }
151
152    @Override
153    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
154        Prediction<Regressor> prediction = predict(example);
155        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
156
157        SparseVector features = createFeatures(example);
158        for (int i = 0; i < dimensions.length; i++) {
159            List<Pair<String, Double>> classScores = new ArrayList<>();
160            for (VectorTuple f : features) {
161                double score = weights[i].get(f.index) * f.value;
162                classScores.add(new Pair<>(featureIDMap.get(f.index).getName(), score));
163            }
164            classScores.sort((Pair<String, Double> o1, Pair<String, Double> o2) -> o2.getB().compareTo(o1.getB()));
165            weightMap.put(dimensions[i], classScores);
166        }
167
168        return Optional.of(new Excuse<>(example, prediction, weightMap));
169    }
170
171    @Override
172    protected Model<Regressor> copy(String newName, ModelProvenance newProvenance) {
173        return new SparseLinearModel(newName,Arrays.copyOf(dimensions,dimensions.length),
174                newProvenance,featureIDMap,outputIDInfo,
175                copyWeights(),
176                featureMeans.copy(),featureVariance.copy(),
177                Arrays.copyOf(yMean,yMean.length), Arrays.copyOf(yVariance,yVariance.length), bias);
178    }
179
180    private SparseVector[] copyWeights() {
181        SparseVector[] newWeights = new SparseVector[weights.length];
182
183        for (int i = 0; i < weights.length; i++) {
184            newWeights[i] = weights[i].copy();
185        }
186
187        return newWeights;
188    }
189
190    /**
191     * Gets a copy of the model parameters.
192     * @return A map from the dimension name to the model parameters.
193     */
194    public Map<String,SparseVector> getWeights() {
195        SparseVector[] newWeights = copyWeights();
196        Map<String,SparseVector> output = new HashMap<>();
197        for (int i = 0; i < dimensions.length; i++) {
198            output.put(dimensions[i],newWeights[i]);
199        }
200        return output;
201    }
202}