001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.sgd.linear;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.Feature;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.Model;
026import org.tribuo.Prediction;
027import org.tribuo.math.LinearParameters;
028import org.tribuo.math.la.DenseMatrix;
029import org.tribuo.math.la.DenseVector;
030import org.tribuo.math.la.SparseVector;
031import org.tribuo.provenance.ModelProvenance;
032import org.tribuo.regression.Regressor;
033
034import java.util.ArrayList;
035import java.util.Arrays;
036import java.util.Collections;
037import java.util.Comparator;
038import java.util.HashMap;
039import java.util.List;
040import java.util.Map;
041import java.util.Optional;
042import java.util.PriorityQueue;
043
044/**
045 * The inference time version of a linear model trained using SGD.
046 * The output dimensions are independent, unless they are tied together by the
047 * optimiser.
048 * <p>
049 * See:
050 * <pre>
051 * Bottou L.
052 * "Large-Scale Machine Learning with Stochastic Gradient Descent"
053 * Proceedings of COMPSTAT, 2010.
054 * </pre>
055 */
056public class LinearSGDModel extends Model<Regressor> {
057    private static final long serialVersionUID = 3L;
058
059    private final String[] dimensionNames;
060    private final DenseMatrix weights;
061
062    LinearSGDModel(String name, String[] dimensionNames, ModelProvenance description,
063                          ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap,
064                          LinearParameters parameters) {
065        super(name, description, featureIDMap, labelIDMap, false);
066        this.weights = parameters.getWeightMatrix();
067        this.dimensionNames = dimensionNames;
068    }
069
070    private LinearSGDModel(String name, String[] dimensionNames, ModelProvenance description,
071                          ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap,
072                          DenseMatrix weights) {
073        super(name, description, featureIDMap, labelIDMap, false);
074        this.weights = weights;
075        this.dimensionNames = dimensionNames;
076    }
077
078    @Override
079    public Prediction<Regressor> predict(Example<Regressor> example) {
080        SparseVector features = SparseVector.createSparseVector(example,featureIDMap,true);
081        if (features.numActiveElements() == 1) {
082            throw new IllegalArgumentException("No features found in Example " + example.toString());
083        }
084        DenseVector prediction = weights.leftMultiply(features);
085        return new Prediction<>(new Regressor(dimensionNames,prediction.toArray()), features.numActiveElements(), example);
086    }
087
088    @Override
089    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
090        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;
091
092        Comparator<Pair<String,Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
093
094        //
095        // Use a priority queue to find the top N features.
096        int numClasses = weights.getDimension1Size();
097        int numFeatures = weights.getDimension2Size()-1; //Removing the bias feature.
098        Map<String, List<Pair<String,Double>>> map = new HashMap<>();
099        for (int i = 0; i < numClasses; i++) {
100            PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures, comparator);
101
102            for (int j = 0; j < numFeatures; j++) {
103                Pair<String,Double> curr = new Pair<>(featureIDMap.get(j).getName(), weights.get(i,j));
104
105                if (q.size() < maxFeatures) {
106                    q.offer(curr);
107                } else if (comparator.compare(curr, q.peek()) > 0) {
108                    q.poll();
109                    q.offer(curr);
110                }
111            }
112            Pair<String,Double> curr = new Pair<>(BIAS_FEATURE, weights.get(i,numFeatures));
113
114            if (q.size() < maxFeatures) {
115                q.offer(curr);
116            } else if (comparator.compare(curr, q.peek()) > 0) {
117                q.poll();
118                q.offer(curr);
119            }
120            ArrayList<Pair<String,Double>> b = new ArrayList<>();
121            while (q.size() > 0) {
122                b.add(q.poll());
123            }
124
125            Collections.reverse(b);
126            map.put(dimensionNames[i], b);
127        }
128        return map;
129    }
130
131    @Override
132    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
133        Prediction<Regressor> prediction = predict(example);
134        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
135        int numOutputs = weights.getDimension1Size();
136        int numFeatures = weights.getDimension2Size()-1; //Remove bias feature
137
138        for (int i = 0; i < numOutputs; i++) {
139            List<Pair<String, Double>> classScores = new ArrayList<>();
140            for (Feature f : example) {
141                int id = featureIDMap.getID(f.getName());
142                if (id > -1) {
143                    double score = weights.get(i,id) * f.getValue();
144                    classScores.add(new Pair<>(f.getName(), score));
145                }
146            }
147            classScores.add(new Pair<>(BIAS_FEATURE, weights.get(i,numFeatures)));
148            classScores.sort((Pair<String, Double> o1, Pair<String, Double> o2) -> o2.getB().compareTo(o1.getB()));
149            weightMap.put(dimensionNames[i], classScores);
150        }
151
152        return Optional.of(new Excuse<>(example, prediction, weightMap));
153    }
154
155    @Override
156    protected LinearSGDModel copy(String newName, ModelProvenance newProvenance) {
157        return new LinearSGDModel(newName,Arrays.copyOf(dimensionNames,dimensionNames.length),newProvenance,featureIDMap,outputIDInfo,getWeightsCopy());
158    }
159
160    public DenseMatrix getWeightsCopy() {
161        return weights.copy();
162    }
163}