001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.mnb; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import com.oracle.labs.mlrg.olcut.util.Pair; 022import org.tribuo.Dataset; 023import org.tribuo.Example; 024import org.tribuo.Feature; 025import org.tribuo.ImmutableFeatureMap; 026import org.tribuo.ImmutableOutputInfo; 027import org.tribuo.Model; 028import org.tribuo.Trainer; 029import org.tribuo.WeightedExamples; 030import org.tribuo.classification.Label; 031import org.tribuo.math.la.DenseSparseMatrix; 032import org.tribuo.math.la.SparseVector; 033import org.tribuo.provenance.ModelProvenance; 034import org.tribuo.provenance.TrainerProvenance; 035import org.tribuo.provenance.impl.TrainerProvenanceImpl; 036 037import java.time.OffsetDateTime; 038import java.util.HashMap; 039import java.util.Map; 040 041/** 042 * A {@link Trainer} which trains a multinomial Naive Bayes model with Laplace smoothing. 043 * <p> 044 * All feature values must be non-negative. 045 * <p> 046 * See: 047 * <pre> 048 * Wang S, Manning CD. 049 * "Baselines and Bigrams: Simple, Good Sentiment and Topic Classification" 050 * Proceedings of the 50th Annual Meeting of the Association for Computational Linguistics, 2012. 051 * </pre> 052 */ 053public class MultinomialNaiveBayesTrainer implements Trainer<Label>, WeightedExamples { 054 055 @Config(description="Smoothing parameter.") 056 private double alpha = 1.0; 057 058 private int invocationCount = 0; 059 060 public MultinomialNaiveBayesTrainer() { 061 this(1.0); 062 } 063 064 //TODO support different alphas for different features? 065 public MultinomialNaiveBayesTrainer(double alpha) { 066 if(alpha <= 0.0) { 067 throw new IllegalArgumentException("alpha parameter must be > 0"); 068 } 069 this.alpha = alpha; 070 } 071 072 @Override 073 public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) { 074 if (examples.getOutputInfo().getUnknownCount() > 0) { 075 throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised."); 076 } 077 ImmutableOutputInfo<Label> labelInfos = examples.getOutputIDInfo(); 078 ImmutableFeatureMap featureInfos = examples.getFeatureIDMap(); 079 080 Map<Integer, Map<Integer, Double>> labelWeights = new HashMap<>(); 081 082 for (Pair<Integer,Label> label : labelInfos) { 083 labelWeights.put(label.getA(), new HashMap<>()); 084 } 085 086 for (Example<Label> ex : examples) { 087 int idx = labelInfos.getID(ex.getOutput()); 088 Map<Integer, Double> featureMap = labelWeights.get(idx); 089 double curWeight = ex.getWeight(); 090 for (Feature feat : ex) { 091 if (feat.getValue() < 0.0) { 092 throw new IllegalStateException("Multinomial Naive Bayes requires non-negative features. Found feature " + feat.toString()); 093 } 094 featureMap.merge(featureInfos.getID(feat.getName()), curWeight*feat.getValue(), Double::sum); 095 } 096 } 097 098 TrainerProvenance trainerProvenance = getProvenance(); 099 ModelProvenance provenance = new ModelProvenance(MultinomialNaiveBayesModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance); 100 invocationCount++; 101 102 SparseVector[] labelVectors = new SparseVector[labelInfos.size()]; 103 104 for(int i = 0; i < labelInfos.size(); i++) { 105 SparseVector sv = SparseVector.createSparseVector(featureInfos.size(), labelWeights.get(i)); 106 double unsmoothedZ = sv.oneNorm(); 107 sv.foreachInPlace(d -> Math.log((d + alpha) / (unsmoothedZ + (featureInfos.size() * alpha)))); 108 labelVectors[i] = sv; 109 } 110 111 DenseSparseMatrix labelWordProbs = DenseSparseMatrix.createFromSparseVectors(labelVectors); 112 113 return new MultinomialNaiveBayesModel("", provenance, featureInfos, labelInfos, labelWordProbs, alpha); 114 } 115 116 @Override 117 public int getInvocationCount() { 118 return invocationCount; 119 } 120 121 @Override 122 public String toString() { 123 return "MultinomialNaiveBayesTrainer(alpha=" + alpha + ")"; 124 } 125 126 @Override 127 public TrainerProvenance getProvenance() { 128 return new TrainerProvenanceImpl(this); 129 } 130}