/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.mnb;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.util.ExpNormalizer;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * A {@link Model} for multinomial Naive Bayes with Laplace smoothing.
 * <p>
 * All feature values must be non-negative; {@link #predict} throws an
 * {@link IllegalArgumentException} if it encounters a negative value.
 * <p>
 * See:
 * <pre>
 * Wang S, Manning CD.
 * "Baselines and Bigrams: Simple, Good Sentiment and Topic Classification"
 * Proceedings of the 50th Annual Meeting of the Association for Computational Linguistics, 2012.
 * </pre>
 */
public class MultinomialNaiveBayesModel extends Model<Label> {
    private static final long serialVersionUID = 1L;

    /** The (num labels x num features) matrix of smoothed log probabilities. */
    private final DenseSparseMatrix labelWordProbs;
    /** The Laplace smoothing parameter. */
    private final double alpha;

    /** Converts the vector of per-label log scores into a probability distribution. */
    private static final VectorNormalizer normalizer = new ExpNormalizer();

    MultinomialNaiveBayesModel(String name, ModelProvenance description, ImmutableFeatureMap featureInfos, ImmutableOutputInfo<Label> labelInfos, DenseSparseMatrix labelWordProbs, double alpha) {
        super(name, description, featureInfos, labelInfos, true);
        this.labelWordProbs = labelWordProbs;
        this.alpha = alpha;
    }

    @Override
    public Prediction<Label> predict(Example<Label> example) {
        SparseVector exVector = SparseVector.createSparseVector(example, featureIDMap, false);

        if (exVector.minValue() < 0.0) {
            throw new IllegalArgumentException("Example has negative feature values, example = " + example.toString());
        }
        if (exVector.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }

        /* Since we keep the label by feature matrix sparse, we need to manually
         * add the weights contributed by smoothing unobserved features. We need to
         * add in the portion of the inner product for the indices that are active
         * in the example but are not active in the labelWordProbs matrix (but are
         * still non-zero due to smoothing).
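         *
         * As a worked restatement (standard multinomial naive Bayes, describing
         * rather than changing the behavior below): with Laplace smoothing the
         * per-label log-likelihood is
         *
         *     log P(x|y) = sum_j x_j * log theta_{y,j},
         *     theta_{y,j} = (count_{y,j} + alpha) / (N_y + V * alpha)
         *
         * where N_y is the total feature count observed with label y and V is
         * the vocabulary size. A feature j never observed with label y has
         * count_{y,j} = 0 but still contributes
         * x_j * log(alpha / (N_y + V * alpha)) to the sum; alphaOffsets
         * restores exactly that missing portion of the inner product.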
         */
        double[] alphaOffsets = new double[outputIDInfo.size()];
        int vocabSize = labelWordProbs.getDimension2Size();
        if (alpha > 0.0) {
            for (int i = 0; i < outputIDInfo.size(); i++) {
                double unobservedProb = Math.log(alpha / (labelWordProbs.getRow(i).oneNorm() + (vocabSize * alpha)));
                int[] mismatchedIndices = exVector.difference(labelWordProbs.getRow(i));
                double inExampleFactor = 0.0;
                for (int idx = 0; idx < mismatchedIndices.length; idx++) {
                    // TODO - exVector.get is slow as it does a binary search into the vector.
                    inExampleFactor += exVector.get(mismatchedIndices[idx]) * unobservedProb;
                }
                alphaOffsets[i] = inExampleFactor;
            }
        }

        DenseVector prediction = labelWordProbs.leftMultiply(exVector);
        prediction.intersectAndAddInPlace(DenseVector.createDenseVector(alphaOffsets));
        prediction.normalize(normalizer);
        Map<String,Label> distribution = new LinkedHashMap<>();
        Label maxLabel = null;
        double maxScore = Double.NEGATIVE_INFINITY;
        for (VectorTuple vt : prediction) {
            String name = outputIDInfo.getOutput(vt.index).getLabel();
            Label label = new Label(name, vt.value);
            if (vt.value > maxScore) {
                maxScore = vt.value;
                maxLabel = label;
            }
            distribution.put(name,label);
        }
        return new Prediction<>(maxLabel, distribution, exVector.numActiveElements(), example, true);
    }

    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        Map<String, List<Pair<String, Double>>> topFeatures = new HashMap<>();

        for (Pair<Integer,Label> label : outputIDInfo) {
            List<Pair<String, Double>> features = new ArrayList<>(labelWordProbs.numActiveElements(label.getA()));
            for (VectorTuple vt : labelWordProbs.getRow(label.getA())) {
                features.add(new Pair<>(featureIDMap.get(vt.index).getName(), vt.value));
            }
            features.sort(Comparator.comparing(x -> -x.getB()));
            // Truncate against the per-label list size; a sparse row can have fewer
            // active features than maxFeatures, and subList would throw in that case.
            if (maxFeatures < features.size()) {
                features = features.subList(0, maxFeatures);
            }
            topFeatures.put(label.getB().getLabel(), features);
        }
        return topFeatures;
    }

    @Override
    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
        Map<String, List<Pair<String, Double>>> explanation = new HashMap<>();
        for (Pair<Integer,Label> label : outputIDInfo) {
            List<Pair<String, Double>> scores = new ArrayList<>();
            for (Feature f : example) {
                int id = featureIDMap.getID(f.getName());
                if (id > -1) {
                    scores.add(new Pair<>(f.getName(),labelWordProbs.getRow(label.getA()).get(id)));
                }
            }
            explanation.put(label.getB().getLabel(), scores);
        }
        return Optional.of(new Excuse<>(example, predict(example), explanation));
    }

    @Override
    protected MultinomialNaiveBayesModel copy(String newName, ModelProvenance newProvenance) {
        return new MultinomialNaiveBayesModel(newName,newProvenance,featureIDMap,outputIDInfo,new DenseSparseMatrix(labelWordProbs),alpha);
    }
}
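
/* Usage sketch (illustrative, not part of this class): models of this type are
 * produced by the trainer in this package. Assuming a MultinomialNaiveBayesTrainer
 * constructed with the smoothing parameter alpha, and pre-built Label datasets
 * named trainData and testData (hypothetical variable names), prediction looks like:
 *
 *     MultinomialNaiveBayesTrainer trainer = new MultinomialNaiveBayesTrainer(1.0);
 *     Model<Label> model = trainer.train(trainData);
 *     Prediction<Label> prediction = model.predict(testData.getExample(0));
 *     System.out.println(prediction.getOutput().getLabel());
 */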