001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.mnb;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import org.tribuo.Dataset;
023import org.tribuo.Example;
024import org.tribuo.Feature;
025import org.tribuo.ImmutableFeatureMap;
026import org.tribuo.ImmutableOutputInfo;
027import org.tribuo.Model;
028import org.tribuo.Trainer;
029import org.tribuo.WeightedExamples;
030import org.tribuo.classification.Label;
031import org.tribuo.math.la.DenseSparseMatrix;
032import org.tribuo.math.la.SparseVector;
033import org.tribuo.provenance.ModelProvenance;
034import org.tribuo.provenance.TrainerProvenance;
035import org.tribuo.provenance.impl.TrainerProvenanceImpl;
036
037import java.time.OffsetDateTime;
038import java.util.HashMap;
039import java.util.Map;
040
041/**
042 * A {@link Trainer} which trains a multinomial Naive Bayes model with Laplace smoothing.
043 * <p>
044 * All feature values must be non-negative.
045 * <p>
046 * See:
047 * <pre>
048 * Wang S, Manning CD.
049 * "Baselines and Bigrams: Simple, Good Sentiment and Topic Classification"
050 * Proceedings of the 50th Annual Meeting of the Association for Computational Linguistics, 2012.
051 * </pre>
052 */
053public class MultinomialNaiveBayesTrainer implements Trainer<Label>, WeightedExamples {
054
055    @Config(description="Smoothing parameter.")
056    private double alpha = 1.0;
057
058    private int invocationCount = 0;
059
060    public MultinomialNaiveBayesTrainer() {
061        this(1.0);
062    }
063
064    //TODO support different alphas for different features?
065    public MultinomialNaiveBayesTrainer(double alpha) {
066        if(alpha <= 0.0) {
067            throw new IllegalArgumentException("alpha parameter must be > 0");
068        }
069        this.alpha = alpha;
070    }
071
072    @Override
073    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
074        if (examples.getOutputInfo().getUnknownCount() > 0) {
075            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
076        }
077        ImmutableOutputInfo<Label> labelInfos = examples.getOutputIDInfo();
078        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
079
080        Map<Integer, Map<Integer, Double>> labelWeights = new HashMap<>();
081
082        for (Pair<Integer,Label> label : labelInfos) {
083            labelWeights.put(label.getA(), new HashMap<>());
084        }
085
086        for (Example<Label> ex : examples) {
087            int idx = labelInfos.getID(ex.getOutput());
088            Map<Integer, Double> featureMap = labelWeights.get(idx);
089            double curWeight = ex.getWeight();
090            for (Feature feat : ex) {
091                if (feat.getValue() < 0.0) {
092                    throw new IllegalStateException("Multinomial Naive Bayes requires non-negative features. Found feature " + feat.toString());
093                }
094                featureMap.merge(featureInfos.getID(feat.getName()), curWeight*feat.getValue(), Double::sum);
095            }
096        }
097
098        TrainerProvenance trainerProvenance = getProvenance();
099        ModelProvenance provenance = new ModelProvenance(MultinomialNaiveBayesModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
100        invocationCount++;
101
102        SparseVector[] labelVectors = new SparseVector[labelInfos.size()];
103
104        for(int i = 0; i < labelInfos.size(); i++) {
105            SparseVector sv = SparseVector.createSparseVector(featureInfos.size(), labelWeights.get(i));
106            double unsmoothedZ = sv.oneNorm();
107            sv.foreachInPlace(d -> Math.log((d + alpha) / (unsmoothedZ + (featureInfos.size() * alpha))));
108            labelVectors[i] = sv;
109        }
110
111        DenseSparseMatrix labelWordProbs = DenseSparseMatrix.createFromSparseVectors(labelVectors);
112
113        return new MultinomialNaiveBayesModel("", provenance, featureInfos, labelInfos, labelWordProbs, alpha);
114    }
115
116    @Override
117    public int getInvocationCount() {
118        return invocationCount;
119    }
120
121    @Override
122    public String toString() {
123        return "MultinomialNaiveBayesTrainer(alpha=" + alpha + ")";
124    }
125
126    @Override
127    public TrainerProvenance getProvenance() {
128        return new TrainerProvenanceImpl(this);
129    }
130}