/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.mnb;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.util.ExpNormalizer;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * A {@link Model} for multinomial Naive Bayes with Laplace smoothing.
 * <p>
 * All feature values must be non-negative; otherwise an {@link IllegalArgumentException} is thrown at prediction time.
 * <p>
 * See:
 * <pre>
 * Wang S, Manning CD.
 * "Baselines and Bigrams: Simple, Good Sentiment and Topic Classification"
 * Proceedings of the 50th Annual Meeting of the Association for Computational Linguistics, 2012.
 * </pre>
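 * <p>
 * A typical usage sketch, assuming a model produced by the {@code MultinomialNaiveBayesTrainer} in this
 * package ({@code trainingData} and {@code testExample} are placeholders for a {@code Dataset<Label>}
 * and an {@code Example<Label>} respectively):
 * <pre>
 * {@code
 * MultinomialNaiveBayesTrainer trainer = new MultinomialNaiveBayesTrainer(1.0); // Laplace smoothing with alpha = 1.0
 * Model<Label> model = trainer.train(trainingData);
 * Prediction<Label> prediction = model.predict(testExample);
 * }
 * </pre>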
 */
public class MultinomialNaiveBayesModel extends Model<Label> {
    private static final long serialVersionUID = 1L;

    private final DenseSparseMatrix labelWordProbs;
    private final double alpha;

    private static final VectorNormalizer normalizer = new ExpNormalizer();

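    /**
     * Constructs a multinomial Naive Bayes model from the per-label feature weight matrix
     * produced by training and the Laplace smoothing parameter used to build it.
     */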
    MultinomialNaiveBayesModel(String name, ModelProvenance description, ImmutableFeatureMap featureInfos, ImmutableOutputInfo<Label> labelInfos, DenseSparseMatrix labelWordProbs, double alpha) {
        super(name, description, featureInfos, labelInfos, true);
        this.labelWordProbs = labelWordProbs;
        this.alpha = alpha;
    }

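    /**
     * Predicts the label distribution for the supplied example.
     * @throws IllegalArgumentException If the example contains negative feature values or
     * none of its features are present in this model's feature map.
     */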
    @Override
    public Prediction<Label> predict(Example<Label> example) {
        SparseVector exVector = SparseVector.createSparseVector(example, featureIDMap, false);

        if (exVector.minValue() < 0.0) {
            throw new IllegalArgumentException("Example has negative feature values, example = " + example.toString());
        }
        if (exVector.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }

        /* Since we keep the label by feature matrix sparse, we need to manually
         * add the weights contributed by smoothing unobserved features. We need to
         * add in the portion of the inner product for the indices that are active
         * in the example but are not active in the labelWordProbs matrix (but are
         * still non-zero due to smoothing).
         */
        double[] alphaOffsets = new double[outputIDInfo.size()];
        int vocabSize = labelWordProbs.getDimension2Size();
        if (alpha > 0.0) {
            for (int i = 0; i < outputIDInfo.size(); i++) {
                double unobservedProb = Math.log(alpha / (labelWordProbs.getRow(i).oneNorm() + (vocabSize * alpha)));
                int[] mismatchedIndices = exVector.difference(labelWordProbs.getRow(i));
                double inExampleFactor = 0.0;
                for (int idx = 0; idx < mismatchedIndices.length; idx++) {
                    // TODO - exVector.get is slow as it does a binary search into the vector.
                    inExampleFactor += exVector.get(mismatchedIndices[idx]) * unobservedProb;
                }
                alphaOffsets[i] = inExampleFactor;
            }
        }

        DenseVector prediction = labelWordProbs.leftMultiply(exVector);
        prediction.intersectAndAddInPlace(DenseVector.createDenseVector(alphaOffsets));
        prediction.normalize(normalizer);
        Map<String,Label> distribution = new LinkedHashMap<>();
        Label maxLabel = null;
        double maxScore = Double.NEGATIVE_INFINITY;
        for (VectorTuple vt : prediction) {
            String name = outputIDInfo.getOutput(vt.index).getLabel();
            Label label = new Label(name, vt.value);
            if (vt.value > maxScore) {
                maxScore = vt.value;
                maxLabel = label;
            }
            distribution.put(name, label);
        }
        return new Prediction<>(maxLabel, distribution, exVector.numActiveElements(), example, true);
    }

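    /**
     * {@inheritDoc}
     * <p>
     * Features are ranked per label by the weight stored in the label-feature matrix, largest first.
     */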
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        Map<String, List<Pair<String, Double>>> topFeatures = new HashMap<>();

        for (Pair<Integer,Label> label : outputIDInfo) {
            List<Pair<String, Double>> features = new ArrayList<>(labelWordProbs.numActiveElements(label.getA()));
            for (VectorTuple vt : labelWordProbs.getRow(label.getA())) {
                features.add(new Pair<>(featureIDMap.get(vt.index).getName(), vt.value));
            }
            features.sort(Comparator.comparing(x -> -x.getB()));
            // Only truncate when the row actually has more active features than requested,
            // otherwise subList would throw for rows with fewer active elements than n.
            if (maxFeatures < features.size()) {
                features = features.subList(0, maxFeatures);
            }
            topFeatures.put(label.getB().getLabel(), features);
        }
        return topFeatures;
    }

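    /**
     * {@inheritDoc}
     * <p>
     * The returned scores are the weights this model stores for each of the example's
     * features which appear in the feature map, reported separately for every label.
     */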
    @Override
    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
        Map<String, List<Pair<String, Double>>> explanation = new HashMap<>();
        for (Pair<Integer,Label> label : outputIDInfo) {
            List<Pair<String, Double>> scores = new ArrayList<>();
            for (Feature f : example) {
                int id = featureIDMap.getID(f.getName());
                if (id > -1) {
                    scores.add(new Pair<>(f.getName(), labelWordProbs.getRow(label.getA()).get(id)));
                }
            }
            explanation.put(label.getB().getLabel(), scores);
        }
        return Optional.of(new Excuse<>(example, predict(example), explanation));
    }

    @Override
    protected MultinomialNaiveBayesModel copy(String newName, ModelProvenance newProvenance) {
        return new MultinomialNaiveBayesModel(newName, newProvenance, featureIDMap, outputIDInfo, new DenseSparseMatrix(labelWordProbs), alpha);
    }
}