001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.baseline;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Prediction;
026import org.tribuo.classification.ImmutableLabelInfo;
027import org.tribuo.classification.Label;
028import org.tribuo.classification.LabelFactory;
029import org.tribuo.classification.baseline.DummyClassifierTrainer.DummyType;
030import org.tribuo.provenance.ModelProvenance;
031import org.tribuo.util.Util;
032
033import java.util.Collections;
034import java.util.HashMap;
035import java.util.List;
036import java.util.Map;
037import java.util.Optional;
038import java.util.Random;
039
040import static org.tribuo.Trainer.DEFAULT_SEED;
041
042/**
043 * A model which performs dummy classifications (e.g., constant output, uniform sampled labels, stratified sampled labels).
044 */
045public class DummyClassifierModel extends Model<Label> {
046    private static final long serialVersionUID = 1L;
047
048    private final DummyType dummyType;
049
050    private final Label constantLabel;
051
052    private final double[] cdf;
053
054    private final Random rng;
055
056    private final long seed;
057
058    DummyClassifierModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo) {
059        super("dummy-MOST_FREQUENT-classifier", description, featureIDMap, outputIDInfo, false);
060        this.dummyType = DummyType.MOST_FREQUENT;
061        this.constantLabel = findMostFrequentLabel(outputIDInfo);
062        this.cdf = null;
063        this.seed = DEFAULT_SEED;
064        this.rng = null;
065    }
066
067    DummyClassifierModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, DummyType dummyType, long seed) {
068        super("dummy-"+dummyType+"-classifier", description, featureIDMap, outputIDInfo, false);
069        this.dummyType = dummyType;
070        this.constantLabel = LabelFactory.UNKNOWN_LABEL;
071        this.cdf = dummyType == DummyType.UNIFORM ? generateUniformCDF(outputIDInfo) : generateStratifiedCDF(outputIDInfo);
072        this.seed = seed;
073        this.rng = new Random(seed);
074    }
075
076    DummyClassifierModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, Label constantLabel) {
077        super("dummy-CONSTANT-classifier", description, featureIDMap, outputIDInfo, false);
078        this.dummyType = DummyType.CONSTANT;
079        this.constantLabel = constantLabel;
080        this.cdf = null;
081        this.seed = DEFAULT_SEED;
082        this.rng = null;
083    }
084
085    @Override
086    public Prediction<Label> predict(Example<Label> example) {
087        switch (dummyType) {
088            case CONSTANT:
089            case MOST_FREQUENT:
090                return new Prediction<>(constantLabel,0,example);
091            case UNIFORM:
092            case STRATIFIED:
093                return new Prediction<>(sampleLabel(cdf,outputIDInfo,rng),0,example);
094            default:
095                throw new IllegalStateException("Unknown dummyType " + dummyType);
096        }
097    }
098
099    @Override
100    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
101        Map<String,List<Pair<String,Double>>> map = new HashMap<>();
102        if (n != 0) {
103            map.put(Model.ALL_OUTPUTS, Collections.singletonList(new Pair<>(BIAS_FEATURE, 1.0)));
104        }
105        return map;
106    }
107
108    @Override
109    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
110        return Optional.of(new Excuse<>(example,predict(example),getTopFeatures(1)));
111    }
112
113    @Override
114    protected DummyClassifierModel copy(String newName, ModelProvenance newProvenance) {
115        switch (dummyType) {
116            case CONSTANT:
117                return new DummyClassifierModel(newProvenance,featureIDMap,outputIDInfo,constantLabel.copy());
118            case MOST_FREQUENT:
119                return new DummyClassifierModel(newProvenance,featureIDMap,outputIDInfo);
120            case UNIFORM:
121            case STRATIFIED:
122                return new DummyClassifierModel(newProvenance,featureIDMap,outputIDInfo,dummyType,seed);
123            default:
124                throw new IllegalStateException("Unknown dummyType " + dummyType);
125        }
126    }
127
128    /**
129     * Samples a label from the supplied CDF.
130     * @param cdf The CDF to sample from.
131     * @param outputIDInfo The mapping from label ids to values.
132     * @param rng The RNG to use.
133     * @return A Label.
134     */
135    private static Label sampleLabel(double[] cdf, ImmutableOutputInfo<Label> outputIDInfo, Random rng) {
136        int sample = Util.sampleFromCDF(cdf,rng);
137        return outputIDInfo.getOutput(sample);
138    }
139
140    /**
141     * Finds the most frequent label and returns it.
142     * @param outputInfo The output information (must be a subclass of ImmutableLabelInfo).
143     * @return The most frequent label.
144     */
145    private static Label findMostFrequentLabel(ImmutableOutputInfo<Label> outputInfo) {
146        Label maxLabel = null;
147        long count = -1;
148
149        ImmutableLabelInfo labelInfo = (ImmutableLabelInfo) outputInfo;
150
151        for (Pair<Integer,Label> p : labelInfo) {
152            long curCount = labelInfo.getLabelCount(p.getA());
153            if (curCount > count) {
154                count = curCount;
155                maxLabel = p.getB();
156            }
157        }
158
159        return maxLabel;
160    }
161
162    /**
163     * Generates a uniform CDF for the supplied labels.
164     * @param outputInfo The output information.
165     * @return A uniform CDF across the domain.
166     */
167    private static double[] generateUniformCDF(ImmutableOutputInfo<Label> outputInfo) {
168        int length = outputInfo.getDomain().size();
169        double[] pmf = Util.generateUniformVector(length,1.0/length);
170        return Util.generateCDF(pmf);
171    }
172
173    /**
174     * Generates a CDF where the label probabilities are proportional to their observed counts.
175     * @param outputInfo The output information.
176     * @return A CDF proportional to the observed counts.
177     */
178    private static double[] generateStratifiedCDF(ImmutableOutputInfo<Label> outputInfo) {
179        ImmutableLabelInfo labelInfo = (ImmutableLabelInfo) outputInfo;
180        int length = labelInfo.getDomain().size();
181        long counts = labelInfo.getTotalObservations();
182
183        double[] pmf = new double[length];
184
185        for (Pair<Integer,Label> p : labelInfo) {
186            int idx = p.getA();
187            pmf[idx] = labelInfo.getLabelCount(idx) / (double) counts;
188        }
189
190        return Util.generateCDF(pmf);
191    }
192}