001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.baseline; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.Excuse; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Model; 025import org.tribuo.Prediction; 026import org.tribuo.classification.ImmutableLabelInfo; 027import org.tribuo.classification.Label; 028import org.tribuo.classification.LabelFactory; 029import org.tribuo.classification.baseline.DummyClassifierTrainer.DummyType; 030import org.tribuo.provenance.ModelProvenance; 031import org.tribuo.util.Util; 032 033import java.util.Collections; 034import java.util.HashMap; 035import java.util.List; 036import java.util.Map; 037import java.util.Optional; 038import java.util.Random; 039 040import static org.tribuo.Trainer.DEFAULT_SEED; 041 042/** 043 * A model which performs dummy classifications (e.g., constant output, uniform sampled labels, stratified sampled labels). 044 */ 045public class DummyClassifierModel extends Model<Label> { 046 private static final long serialVersionUID = 1L; 047 048 private final DummyType dummyType; 049 050 private final Label constantLabel; 051 052 private final double[] cdf; 053 054 private final Random rng; 055 056 private final long seed; 057 058 DummyClassifierModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo) { 059 super("dummy-MOST_FREQUENT-classifier", description, featureIDMap, outputIDInfo, false); 060 this.dummyType = DummyType.MOST_FREQUENT; 061 this.constantLabel = findMostFrequentLabel(outputIDInfo); 062 this.cdf = null; 063 this.seed = DEFAULT_SEED; 064 this.rng = null; 065 } 066 067 DummyClassifierModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, DummyType dummyType, long seed) { 068 super("dummy-"+dummyType+"-classifier", description, featureIDMap, outputIDInfo, false); 069 this.dummyType = dummyType; 070 this.constantLabel = LabelFactory.UNKNOWN_LABEL; 071 this.cdf = dummyType == DummyType.UNIFORM ? generateUniformCDF(outputIDInfo) : generateStratifiedCDF(outputIDInfo); 072 this.seed = seed; 073 this.rng = new Random(seed); 074 } 075 076 DummyClassifierModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, Label constantLabel) { 077 super("dummy-CONSTANT-classifier", description, featureIDMap, outputIDInfo, false); 078 this.dummyType = DummyType.CONSTANT; 079 this.constantLabel = constantLabel; 080 this.cdf = null; 081 this.seed = DEFAULT_SEED; 082 this.rng = null; 083 } 084 085 @Override 086 public Prediction<Label> predict(Example<Label> example) { 087 switch (dummyType) { 088 case CONSTANT: 089 case MOST_FREQUENT: 090 return new Prediction<>(constantLabel,0,example); 091 case UNIFORM: 092 case STRATIFIED: 093 return new Prediction<>(sampleLabel(cdf,outputIDInfo,rng),0,example); 094 default: 095 throw new IllegalStateException("Unknown dummyType " + dummyType); 096 } 097 } 098 099 @Override 100 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 101 Map<String,List<Pair<String,Double>>> map = new HashMap<>(); 102 if (n != 0) { 103 map.put(Model.ALL_OUTPUTS, Collections.singletonList(new Pair<>(BIAS_FEATURE, 1.0))); 104 } 105 return map; 106 } 107 108 @Override 109 public Optional<Excuse<Label>> getExcuse(Example<Label> example) { 110 return Optional.of(new Excuse<>(example,predict(example),getTopFeatures(1))); 111 } 112 113 @Override 114 protected DummyClassifierModel copy(String newName, ModelProvenance newProvenance) { 115 switch (dummyType) { 116 case CONSTANT: 117 return new DummyClassifierModel(newProvenance,featureIDMap,outputIDInfo,constantLabel.copy()); 118 case MOST_FREQUENT: 119 return new DummyClassifierModel(newProvenance,featureIDMap,outputIDInfo); 120 case UNIFORM: 121 case STRATIFIED: 122 return new DummyClassifierModel(newProvenance,featureIDMap,outputIDInfo,dummyType,seed); 123 default: 124 throw new IllegalStateException("Unknown dummyType " + dummyType); 125 } 126 } 127 128 /** 129 * Samples a label from the supplied CDF. 130 * @param cdf The CDF to sample from. 131 * @param outputIDInfo The mapping from label ids to values. 132 * @param rng The RNG to use. 133 * @return A Label. 134 */ 135 private static Label sampleLabel(double[] cdf, ImmutableOutputInfo<Label> outputIDInfo, Random rng) { 136 int sample = Util.sampleFromCDF(cdf,rng); 137 return outputIDInfo.getOutput(sample); 138 } 139 140 /** 141 * Finds the most frequent label and returns it. 142 * @param outputInfo The output information (must be a subclass of ImmutableLabelInfo). 143 * @return The most frequent label. 144 */ 145 private static Label findMostFrequentLabel(ImmutableOutputInfo<Label> outputInfo) { 146 Label maxLabel = null; 147 long count = -1; 148 149 ImmutableLabelInfo labelInfo = (ImmutableLabelInfo) outputInfo; 150 151 for (Pair<Integer,Label> p : labelInfo) { 152 long curCount = labelInfo.getLabelCount(p.getA()); 153 if (curCount > count) { 154 count = curCount; 155 maxLabel = p.getB(); 156 } 157 } 158 159 return maxLabel; 160 } 161 162 /** 163 * Generates a uniform CDF for the supplied labels. 164 * @param outputInfo The output information. 165 * @return A uniform CDF across the domain. 166 */ 167 private static double[] generateUniformCDF(ImmutableOutputInfo<Label> outputInfo) { 168 int length = outputInfo.getDomain().size(); 169 double[] pmf = Util.generateUniformVector(length,1.0/length); 170 return Util.generateCDF(pmf); 171 } 172 173 /** 174 * Generates a CDF where the label probabilities are proportional to their observed counts. 175 * @param outputInfo The output information. 176 * @return A CDF proportional to the observed counts. 177 */ 178 private static double[] generateStratifiedCDF(ImmutableOutputInfo<Label> outputInfo) { 179 ImmutableLabelInfo labelInfo = (ImmutableLabelInfo) outputInfo; 180 int length = labelInfo.getDomain().size(); 181 long counts = labelInfo.getTotalObservations(); 182 183 double[] pmf = new double[length]; 184 185 for (Pair<Integer,Label> p : labelInfo) { 186 int idx = p.getA(); 187 pmf[idx] = labelInfo.getLabelCount(idx) / (double) counts; 188 } 189 190 return Util.generateCDF(pmf); 191 } 192}