/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.multilabel.evaluation;

import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.MultiLabelFactory;

import java.util.List;
import java.util.Set;
import java.util.function.Function;

/**
 * A {@link ConfusionMatrix} which accepts {@link MultiLabel}s.
 *
 * <p>
 * In a multi-label confusion matrix M,
 * <pre>
 * tn = M[:, 0, 0]
 * fn = M[:, 1, 0]
 * tp = M[:, 1, 1]
 * fp = M[:, 0, 1]
 * </pre>
 * <p>
 * For class-wise values,
 * <pre>
 * tn(class i) = M[i, 0, 0]
 * fn(class i) = M[i, 1, 0]
 * tp(class i) = M[i, 1, 1]
 * fp(class i) = M[i, 0, 1]
 * </pre>
 * <p>
 * That is, each per-class 2x2 matrix is laid out as
 * <pre>
 * [tn, fp]
 * [fn, tp]
 * </pre>
 * with the row indexing the ground truth (0 = label absent, 1 = label present)
 * and the column indexing the prediction.
 */
public final class MultiLabelConfusionMatrix implements ConfusionMatrix<MultiLabel> {

    /** The output domain used to map labels to indices. */
    private final ImmutableOutputInfo<MultiLabel> domain;

    /** One 2x2 matrix per label in the domain, indexed by the label's id. */
    private final DenseMatrix[] mcm;

    /**
     * Label co-occurrence counts: {@code confusion[p, t]} is incremented once for every
     * (predicted label p, true label t) pair appearing in the same prediction.
     */
    private final DenseMatrix confusion;

    /**
     * Builds the confusion matrix for the supplied predictions using the model's output domain.
     *
     * @param model The model whose output id mapping defines the domain.
     * @param predictions The predictions to tabulate.
     */
    public MultiLabelConfusionMatrix(Model<MultiLabel> model, List<Prediction<MultiLabel>> predictions) {
        this(model.getOutputIDInfo(), predictions);
    }

    /**
     * Builds the confusion matrix for the supplied predictions using the supplied domain.
     *
     * @param domain The output domain.
     * @param predictions The predictions to tabulate.
     */
    MultiLabelConfusionMatrix(ImmutableOutputInfo<MultiLabel> domain, List<Prediction<MultiLabel>> predictions) {
        this.domain = domain;
        ConfusionMatrixTuple tab = tabulate(domain, predictions);
        this.mcm = tab.mcm;
        this.confusion = tab.confusion;
    }

    /**
     * Sums the support (the number of times the label appeared in the ground truth)
     * over each label in the supplied {@link MultiLabel}.
     * <p>
     * Labels which are not in the domain contribute nothing, matching the behaviour
     * of {@link #tp(MultiLabel)} and friends.
     *
     * @param cls The labels to query.
     * @return The summed support.
     */
    @Override
    public double support(MultiLabel cls) {
        double total = 0d;
        for (Label label : cls.getLabelSet()) {
            int ix = getDomain().getID(new MultiLabel(label));
            // When the input class is not in the domain, the id is negative; skip it.
            if (ix < 0) {
                continue;
            }
            total += supportOf(mcm[ix]);
        }
        return total;
    }

    @Override
    public ImmutableOutputInfo<MultiLabel> getDomain() {
        return domain;
    }

    /**
     * Sums the support over every label in the domain.
     *
     * @return The total support.
     */
    @Override
    public double support() {
        double total = 0d;
        for (int i = 0; i < domain.size(); i++) {
            total += supportOf(mcm[i]);
        }
        return total;
    }

    /**
     * Computes the support recorded in a single per-class matrix.
     * <pre>
     * cm =
     *   [tn, fp]
     *   [fn, tp]
     *
     * false neg: ground truth was [label] but it was not predicted
     * true pos:  ground truth was [label] and it was predicted
     *
     * support = false neg + true pos = cm[1, 0] + cm[1, 1]
     * </pre>
     *
     * @param cm The 2x2 matrix for one label.
     * @return The number of times the label appeared in the ground truth.
     */
    private static double supportOf(DenseMatrix cm) {
        return cm.get(1, 0) + cm.get(1, 1);
    }

    /**
     * Sums the true positives over each label in the supplied {@link MultiLabel}.
     *
     * @param cls The labels to query.
     * @return The summed true positive count.
     */
    @Override
    public double tp(MultiLabel cls) {
        return compute(cls, (cm) -> cm.get(1, 1));
    }

    /**
     * Sums the false positives over each label in the supplied {@link MultiLabel}.
     *
     * @param cls The labels to query.
     * @return The summed false positive count.
     */
    @Override
    public double fp(MultiLabel cls) {
        return compute(cls, (cm) -> cm.get(0, 1));
    }

    /**
     * Sums the false negatives over each label in the supplied {@link MultiLabel}.
     *
     * @param cls The labels to query.
     * @return The summed false negative count.
     */
    @Override
    public double fn(MultiLabel cls) {
        return compute(cls, (cm) -> cm.get(1, 0));
    }

    /**
     * Sums the true negatives over each label in the supplied {@link MultiLabel}.
     *
     * @param cls The labels to query.
     * @return The summed true negative count.
     */
    @Override
    public double tn(MultiLabel cls) {
        return compute(cls, (cm) -> cm.get(0, 0));
    }

    /**
     * Applies the getter to the per-class matrix of each label in {@code cls} and sums the results.
     *
     * @param cls The labels to query.
     * @param getter Extracts the desired cell from a per-class 2x2 matrix.
     * @return The sum of the extracted values over the known labels.
     */
    private double compute(MultiLabel cls, Function<DenseMatrix, Double> getter) {
        double total = 0d;
        for (Label label : cls.getLabelSet()) {
            int i = domain.getID(new MultiLabel(label.getLabel()));
            //
            // When input class is not in the domain, ID will be -1.
            if (i < 0) {
                continue;
            }
            DenseMatrix cm = mcm[i];
            total += getter.apply(cm);
        }
        return total;
    }

    /**
     * Sums the co-occurrence counts over every (predicted label, true label) pair
     * drawn from the two supplied {@link MultiLabel}s.
     * <p>
     * Labels which are not in the domain contribute nothing.
     *
     * @param predicted The predicted labels.
     * @param truth The ground truth labels.
     * @return The summed co-occurrence count.
     */
    @Override
    public double confusion(MultiLabel predicted, MultiLabel truth) {
        double total = 0d;
        Set<Label> trueSet = truth.getLabelSet();
        Set<Label> predSet = predicted.getLabelSet();
        for (Label predLabel : predSet) {
            int idx = domain.getID(new MultiLabel(predLabel.getLabel()));
            // Unknown labels have a negative id and no entry in the table.
            if (idx < 0) {
                continue;
            }
            for (Label trueLabel : trueSet) {
                int jdx = domain.getID(new MultiLabel(trueLabel.getLabel()));
                if (jdx < 0) {
                    continue;
                }
                total += this.confusion.get(idx, jdx);
            }
        }
        return total;
    }

    /**
     * Concatenates the string form of each per-class matrix, one per line, inside square brackets.
     *
     * @return The string representation.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        for (DenseMatrix cm : mcm) {
            sb.append(cm.toString());
            sb.append("\n");
        }
        sb.append("]");
        return sb.toString();
    }

    /**
     * Tabulates the predictions into one 2x2 matrix per label, plus a label co-occurrence table.
     * <p>
     * Each per-class matrix is laid out as
     * <pre>
     * [tn, fp]
     * [fn, tp]
     * </pre>
     * which is the layout read by {@link #tp}, {@link #fp}, {@link #fn} and {@link #tn}.
     *
     * @param domain The output domain.
     * @param predictions The predictions to tabulate.
     * @return The per-class matrices and the co-occurrence table.
     * @throws IllegalArgumentException If the sentinel unknown MultiLabel appears as a ground truth
     *         or predicted output, or if a ground truth or predicted label is not in the domain.
     */
    static ConfusionMatrixTuple tabulate(ImmutableOutputInfo<MultiLabel> domain, List<Prediction<MultiLabel>> predictions) {
        // this just keeps track of how many times [class x] was predicted to be [class y]
        DenseMatrix confusion = new DenseMatrix(domain.size(), domain.size());

        DenseMatrix[] mcm = new DenseMatrix[domain.size()];
        for (int i = 0; i < domain.size(); i++) {
            mcm[i] = new DenseMatrix(2, 2);
        }

        int predIndex = 0;
        for (Prediction<MultiLabel> prediction : predictions) {
            MultiLabel predictedOutput = prediction.getOutput();
            MultiLabel trueOutput = prediction.getExample().getOutput();
            if (trueOutput.equals(MultiLabelFactory.UNKNOWN_MULTILABEL)) {
                throw new IllegalArgumentException("The sentinel Unknown MultiLabel was used as a ground truth label at prediction number " + predIndex);
            } else if (predictedOutput.equals(MultiLabelFactory.UNKNOWN_MULTILABEL)) {
                throw new IllegalArgumentException("The sentinel Unknown MultiLabel was predicted by the model at prediction number " + predIndex);
            }

            Set<Label> trueSet = trueOutput.getLabelSet();
            Set<Label> predSet = predictedOutput.getLabelSet();

            //
            // Count true positives and false positives
            for (Label pred : predSet) {
                int idx = domain.getID(new MultiLabel(pred.getLabel()));
                if (idx < 0) {
                    throw new IllegalArgumentException("Unknown label '" + pred.getLabel() + "' found in the predicted labels at prediction number " + predIndex
                            + ", this label is not in the supplied output domain.");
                }
                if (trueSet.contains(pred)) {
                    //
                    // true positive: mcm[i, 1, 1]++
                    mcm[idx].add(1, 1, 1d);
                } else {
                    //
                    // false positive: mcm[i, 0, 1]++
                    mcm[idx].add(0, 1, 1d);
                }
            }

            //
            // Count false negatives and populate the confusion table
            for (Label trueLabel : trueSet) {
                int idx = domain.getID(new MultiLabel(trueLabel.getLabel()));
                if (idx < 0) {
                    throw new IllegalArgumentException("Unknown label '" + trueLabel.getLabel() + "' found in the ground truth labels at prediction number " + predIndex
                            + ", this label is not known by the model which made the predictions.");
                }

                //
                // Doing two things in this loop:
                // 1) Checking if predSet contains trueLabel
                // 2) Counting the # of times [trueLabel] was predicted to be [predLabel] to populate the confusion table
                boolean found = false;
                for (Label predLabel : predSet) {
                    int jdx = domain.getID(new MultiLabel(predLabel.getLabel()));
                    confusion.add(jdx, idx, 1d);

                    if (predLabel.equals(trueLabel)) {
                        found = true;
                    }
                }

                if (!found) {
                    //
                    // false negative: mcm[i, 1, 0]++
                    mcm[idx].add(1, 0, 1d);
                }
                // else { true positive: already counted }
            }

            //
            // True negatives everywhere else
            for (MultiLabel multilabel : domain.getDomain()) {
                Set<Label> labels = multilabel.getLabelSet();
                for (Label label : labels) {
                    if (!trueSet.contains(label) && !predSet.contains(label)) {
                        int ix = domain.getID(new MultiLabel(label));
                        mcm[ix].add(0, 0, 1d);
                    }
                }
            }
            predIndex++;
        }

        return new ConfusionMatrixTuple(mcm, confusion);
    }

    /**
     * A pair holding the per-class matrices and the co-occurrence table.
     * <p>
     * It would be a record, but Java 14 is not required yet.
     */
    static final class ConfusionMatrixTuple {
        final DenseMatrix[] mcm;
        final DenseMatrix confusion;

        ConfusionMatrixTuple(DenseMatrix[] mcm, DenseMatrix confusion) {
            this.mcm = mcm;
            this.confusion = confusion;
        }

        /**
         * Gets the per-class 2x2 matrices.
         *
         * @return The per-class matrices, indexed by label id.
         */
        DenseMatrix[] getMCM() {
            return mcm;
        }
    }
}