001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sgd.linear;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.Feature;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.Model;
026import org.tribuo.Prediction;
027import org.tribuo.classification.Label;
028import org.tribuo.math.LinearParameters;
029import org.tribuo.math.la.DenseMatrix;
030import org.tribuo.math.la.DenseVector;
031import org.tribuo.math.la.SparseVector;
032import org.tribuo.math.util.VectorNormalizer;
033import org.tribuo.provenance.ModelProvenance;
034
035import java.util.ArrayList;
036import java.util.Collections;
037import java.util.Comparator;
038import java.util.HashMap;
039import java.util.LinkedHashMap;
040import java.util.List;
041import java.util.Map;
042import java.util.Optional;
043import java.util.PriorityQueue;
044
045/**
046 * The inference time version of a linear model trained using SGD.
047 * <p>
048 * See:
049 * <pre>
050 * Bottou L.
051 * "Large-Scale Machine Learning with Stochastic Gradient Descent"
052 * Proceedings of COMPSTAT, 2010.
053 * </pre>
054 */
055public class LinearSGDModel extends Model<Label> {
056    private static final long serialVersionUID = 2L;
057
058    private final DenseMatrix weights;
059    private final VectorNormalizer normalizer;
060
061    LinearSGDModel(String name, ModelProvenance description,
062                   ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap,
063                   LinearParameters parameters, VectorNormalizer normalizer, boolean generatesProbabilities) {
064        super(name, description, featureIDMap, labelIDMap, generatesProbabilities);
065        this.weights = parameters.getWeightMatrix();
066        this.normalizer = normalizer;
067    }
068
069    private LinearSGDModel(String name, ModelProvenance description,
070                          ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap,
071                          DenseMatrix weights, VectorNormalizer normalizer, boolean generatesProbabilities) {
072        super(name, description, featureIDMap, labelIDMap, generatesProbabilities);
073        this.weights = weights;
074        this.normalizer = normalizer;
075    }
076
077    @Override
078    public Prediction<Label> predict(Example<Label> example) {
079        SparseVector features = SparseVector.createSparseVector(example,featureIDMap,true);
080        // Due to bias feature
081        if (features.numActiveElements() == 1) {
082            throw new IllegalArgumentException("No features found in Example " + example.toString());
083        }
084        DenseVector prediction = weights.leftMultiply(features);
085        prediction.normalize(normalizer);
086
087        double maxScore = Double.NEGATIVE_INFINITY;
088        Label maxLabel = null;
089        Map<String,Label> predMap = new LinkedHashMap<>();
090        for (int i = 0; i < prediction.size(); i++) {
091            String labelName = outputIDInfo.getOutput(i).getLabel();
092            Label label = new Label(labelName, prediction.get(i));
093            predMap.put(labelName,label);
094            if (label.getScore() > maxScore) {
095                maxScore = label.getScore();
096                maxLabel = label;
097            }
098        }
099        return new Prediction<>(maxLabel, predMap, features.numActiveElements()-1, example, generatesProbabilities);
100    }
101
102    @Override
103    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
104        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;
105
106        Comparator<Pair<String,Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
107
108        //
109        // Use a priority queue to find the top N features.
110        int numClasses = weights.getDimension1Size();
111        int numFeatures = weights.getDimension2Size()-1; //Removing the bias feature.
112        Map<String, List<Pair<String,Double>>> map = new HashMap<>();
113        for (int i = 0; i < numClasses; i++) {
114            PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures, comparator);
115
116            for (int j = 0; j < numFeatures; j++) {
117                Pair<String,Double> curr = new Pair<>(featureIDMap.get(j).getName(), weights.get(i,j));
118
119                if (q.size() < maxFeatures) {
120                    q.offer(curr);
121                } else if (comparator.compare(curr, q.peek()) > 0) {
122                    q.poll();
123                    q.offer(curr);
124                }
125            }
126            Pair<String,Double> curr = new Pair<>(BIAS_FEATURE, weights.get(i,numFeatures));
127
128            if (q.size() < maxFeatures) {
129                q.offer(curr);
130            } else if (comparator.compare(curr, q.peek()) > 0) {
131                q.poll();
132                q.offer(curr);
133            }
134            ArrayList<Pair<String,Double>> b = new ArrayList<>();
135            while (q.size() > 0) {
136                b.add(q.poll());
137            }
138
139            Collections.reverse(b);
140            map.put(outputIDInfo.getOutput(i).getLabel(), b);
141        }
142        return map;
143    }
144
145    @Override
146    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
147        Prediction<Label> prediction = predict(example);
148        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
149        int numClasses = weights.getDimension1Size();
150        int numFeatures = weights.getDimension2Size()-1;
151
152        for (int i = 0; i < numClasses; i++) {
153            List<Pair<String, Double>> classScores = new ArrayList<>();
154            for (Feature f : example) {
155                int id = featureIDMap.getID(f.getName());
156                if (id > -1) {
157                    double score = weights.get(i,id) * f.getValue();
158                    classScores.add(new Pair<>(f.getName(), score));
159                }
160            }
161            classScores.add(new Pair<>(Model.BIAS_FEATURE,weights.get(i,numFeatures)));
162            classScores.sort((Pair<String, Double> o1, Pair<String, Double> o2) -> o2.getB().compareTo(o1.getB()));
163            weightMap.put(outputIDInfo.getOutput(i).getLabel(), classScores);
164        }
165
166        return Optional.of(new Excuse<>(example, prediction, weightMap));
167    }
168
169    @Override
170    protected LinearSGDModel copy(String newName, ModelProvenance newProvenance) {
171        return new LinearSGDModel(newName,newProvenance,featureIDMap,outputIDInfo,new DenseMatrix(weights),normalizer,generatesProbabilities);
172    }
173}