001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.evaluation; 018 019import org.tribuo.ImmutableOutputInfo; 020import org.tribuo.Model; 021import org.tribuo.Prediction; 022import org.tribuo.classification.Label; 023import org.tribuo.math.la.DenseMatrix; 024 025import java.util.ArrayList; 026import java.util.HashMap; 027import java.util.HashSet; 028import java.util.LinkedHashSet; 029import java.util.List; 030import java.util.Map; 031import java.util.Set; 032import java.util.function.ToDoubleFunction; 033import java.util.logging.Logger; 034 035/** 036 * A confusion matrix for {@link Label}s. 037 * <p> 038 * We interpret it as follows: 039 * 040 * {@code 041 * C[i, j] = k 042 * } 043 * 044 * means "the TRUE class 'j' was PREDICTED to be class 'i' a total of 'k' times". 045 * 046 * <p> 047 * In other words, the row indices correspond to the model's predictions, and the column indices correspond to 048 * the ground truth. 049 * </p> 050 */ 051public final class LabelConfusionMatrix implements ConfusionMatrix<Label> { 052 053 private static final Logger logger = Logger.getLogger(LabelConfusionMatrix.class.getName()); 054 055 private final ImmutableOutputInfo<Label> domain; 056 057 private final int total; 058 private final Map<Label, Double> occurrences; 059 060 private final Set<Label> observed; 061 062 private final DenseMatrix cm; 063 064 private List<Label> labelOrder; 065 066 /** 067 * Creates a confusion matrix from the supplied predictions, using the label info 068 * from the supplied model. 069 * @param model The model to use for the label information. 070 * @param predictions The predictions. 071 */ 072 public LabelConfusionMatrix(Model<Label> model, List<Prediction<Label>> predictions) { 073 this(model.getOutputIDInfo(), predictions); 074 } 075 076 /** 077 * Creates a confusion matrix from the supplied predictions and label info. 078 * @throws IllegalArgumentException If the domain doesn't contain all the predictions. 079 * @param domain The label information. 080 * @param predictions The predictions. 081 */ 082 public LabelConfusionMatrix(ImmutableOutputInfo<Label> domain, List<Prediction<Label>> predictions) { 083 this.domain = domain; 084 this.total = predictions.size(); 085 this.cm = new DenseMatrix(domain.size(), domain.size()); 086 this.occurrences = new HashMap<>(); 087 this.observed = new HashSet<>(); 088 tabulate(predictions); 089 } 090 091 /** 092 * Aggregate the predictions into this confusion matrix. 093 * @param predictions The predictions to aggregate. 094 */ 095 private void tabulate(List<Prediction<Label>> predictions) { 096 predictions.forEach(prediction -> { 097 Label y = prediction.getExample().getOutput(); 098 Label p = prediction.getOutput(); 099 // 100 // Check that the ground truth label is valid 101 if (y.getLabel().equals(Label.UNKNOWN)) { 102 throw new IllegalArgumentException("Prediction with unknown ground truth. Unable to evaluate."); 103 } 104 occurrences.merge(y,1d, Double::sum); 105 observed.add(y); 106 observed.add(p); 107 int iy = getIDOrThrow(y); 108 int ip = getIDOrThrow(p); 109 cm.add(ip, iy, 1d); 110 }); 111 } 112 113 @Override 114 public ImmutableOutputInfo<Label> getDomain() { 115 return domain; 116 } 117 118 @Override 119 public double support() { 120 return total; 121 } 122 123 @Override 124 public double support(Label label) { 125 return occurrences.getOrDefault(label, 0d); 126 } 127 128 @Override 129 public double tp(Label cls) { 130 return compute(cls, (i) -> cm.get(i, i)); 131 } 132 133 @Override 134 public double fp(Label cls) { 135 // Row-wise sum less true positives 136 return compute(cls, i -> cm.rowSum(i) - cm.get(i, i)); 137 } 138 139 @Override 140 public double fn(Label cls) { 141 // Column-wise sum less true positives 142 return compute(cls, i -> cm.columnSum(i) - cm.get(i, i)); 143 } 144 145 @Override 146 public double tn(Label cls) { 147 int n = getDomain().size(); 148 int i = getDomain().getID(cls); 149 double total = 0d; 150 for (int j = 0; j < n; j++) { 151 if (j == i) { 152 continue; 153 } 154 for (int k = 0; k < n; k++) { 155 if (k == i) { 156 continue; 157 } 158 total += cm.get(j, k); 159 } 160 } 161 return total; 162 } 163 164 @Override 165 public double confusion(Label predicted, Label trueClass) { 166 int i = getDomain().getID(predicted); 167 int j = getDomain().getID(trueClass); 168 return cm.get(i, j); 169 } 170 171 /** 172 * A convenience method for extracting the appropriate label statistic. 173 * @param cls The label to check. 174 * @param getter The get function which accepts a label id. 175 * @return The statistic for that label id. 176 */ 177 private double compute(Label cls, ToDoubleFunction<Integer> getter) { 178 int i = getDomain().getID(cls); 179 if (i < 0) { 180 logger.fine("Unknown Label " + cls); 181 return 0d; 182 } 183 return getter.applyAsDouble(i); 184 } 185 186 /** 187 * Gets the id for the supplied label, or throws an {@link IllegalArgumentException} if it's 188 * an unknown label. 189 * @param key The label. 190 * @return The int id for that label. 191 */ 192 private int getIDOrThrow(Label key) { 193 int id = domain.getID(key); 194 if (id < 0) { 195 throw new IllegalArgumentException("Unknown label: " + key); 196 } 197 return id; 198 } 199 200 /** 201 * Sets the label order used in {@link #toString}. 202 * @param labelOrder The label order to use. 203 */ 204 public void setLabelOrder(List<Label> labelOrder) { 205 this.labelOrder = labelOrder; 206 } 207 208 @Override 209 public String toString() { 210 if (labelOrder == null) { 211 labelOrder = new ArrayList<>(domain.getDomain()); 212 } 213 labelOrder.retainAll(observed); 214 215 int maxLen = Integer.MIN_VALUE; 216 for (Label label : labelOrder) { 217 maxLen = Math.max(label.getLabel().length(), maxLen); 218 maxLen = Math.max(String.format(" %,d", (int)(double)occurrences.getOrDefault(label,0.0)).length(), maxLen); 219 } 220 221 StringBuilder sb = new StringBuilder(); 222 String trueLabelFormat = String.format("%%-%ds", maxLen + 2); 223 String predictedLabelFormat = String.format("%%%ds", maxLen + 2); 224 String countFormat = String.format("%%,%dd", maxLen + 2); 225 226 // 227 // Empty spot in first row for labels on subsequent rows. 228 sb.append(String.format(trueLabelFormat, "")); 229 230 // 231 // Labels across the top for predicted. 232 for (Label predictedLabel : labelOrder) { 233 sb.append(String.format(predictedLabelFormat, predictedLabel.getLabel())); 234 } 235 sb.append('\n'); 236 237 for (Label trueLabel : labelOrder) { 238 sb.append(String.format(trueLabelFormat, trueLabel.getLabel())); 239 for (Label predictedLabel : labelOrder) { 240 int confusion = (int) confusion(predictedLabel, trueLabel); 241 sb.append(String.format(countFormat, confusion)); 242 } 243 sb.append('\n'); 244 } 245 return sb.toString(); 246 } 247 248 /** 249 * Emits a HTML table representation of the Confusion Matrix. 250 * @return The confusion matrix as a HTML table. 251 */ 252 public String toHTML() { 253 if (labelOrder == null) { 254 labelOrder = new ArrayList<>(domain.getDomain()); 255 } 256 Set<Label> labelsToPrint = new LinkedHashSet<>(labelOrder); 257 labelsToPrint.retainAll(observed); 258 StringBuilder sb = new StringBuilder(); 259 sb.append("<table>\n"); 260 sb.append(String.format("<tr><th>True Label</th><th style=\"text-align:center\" colspan=\"%d\">Predicted Labels</th></tr>%n", occurrences.size() + 1)); 261 sb.append("<tr><th></th>"); 262 for (Label predictedLabel : labelsToPrint) { 263 sb.append("<th style=\"text-align:right\">") 264 .append(predictedLabel) 265 .append("</th>"); 266 } 267 sb.append("<th style=\"text-align:right\">Total</th>"); 268 sb.append("</tr>\n"); 269 for (Label trueLabel : labelsToPrint) { 270 sb.append("<tr><th>").append(trueLabel).append("</th>"); 271 double count = occurrences.getOrDefault(trueLabel, 0d); 272 for (Label predictedLabel : labelsToPrint) { 273 double tlmc = confusion(predictedLabel,trueLabel); 274 double percent = (tlmc / count) * 100; 275 sb.append("<td style=\"text-align:right\">") 276 .append(String.format("%,d (%.1f%%)", (int)tlmc, percent)) 277 .append("</td>"); 278 } 279 sb.append("<td style=\"text-align:right\">").append(count).append("</td>"); 280 sb.append("</tr>\n"); 281 } 282 sb.append("</table>"); 283 return sb.toString(); 284 } 285}