001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.evaluation; 018 019import org.tribuo.classification.Classifiable; 020import org.tribuo.evaluation.metrics.EvaluationMetric.Average; 021import org.tribuo.evaluation.metrics.MetricTarget; 022 023import java.util.logging.Logger; 024 025/** 026 * Static functions for computing classification metrics based on a {@link ConfusionMatrix}. 027 */ 028public final class ConfusionMetrics { 029 030 private static final Logger logger = Logger.getLogger(ConfusionMetrics.class.getName()); 031 032 // singleton 033 private ConfusionMetrics() { } 034 035 /** 036 * Calculates the accuracy given this confusion matrix. 037 * 038 * @param <T> The type parameter 039 * @param target The metric target 040 * @param cm The confusion matrix 041 * @return The accuracy 042 */ 043 public static <T extends Classifiable<T>> double accuracy(MetricTarget<T> target, ConfusionMatrix<T> cm) { 044 if (target.getOutputTarget().isPresent()) { 045 return accuracy(target.getOutputTarget().get(), cm); 046 } else { 047 return accuracy(target.getAverageTarget().get(), cm); 048 } 049 } 050 051 /** 052 * Calculates a per label accuracy given this confusion matrix. 053 * 054 * @param <T> The type parameter 055 * @param label The label 056 * @param cm The confusion matrix 057 * @return The accuracy 058 */ 059 public static <T extends Classifiable<T>> double accuracy(T label, ConfusionMatrix<T> cm) { 060 double support = cm.support(label); 061 // handle div-by-zero 062 if (support == 0d) { 063 logger.warning("No predictions: accuracy ill-defined"); 064 return Double.NaN; 065 } 066 return cm.tp(label) / cm.support(label); 067 } 068 069 /** 070 * Calculates the accuracy using the specified average type and confusion matrix. 071 * 072 * @param <T> the type parameter 073 * @param average the average 074 * @param cm The confusion matrix 075 * @return The accuracy 076 */ 077 public static <T extends Classifiable<T>> double accuracy(Average average, ConfusionMatrix<T> cm) { 078 if (average.equals(Average.MICRO)) { 079 // handle div-by-zero 080 if (cm.support() == 0d) { 081 logger.warning("No predictions: accuracy ill-defined"); 082 return Double.NaN; 083 } 084 return cm.tp() / cm.support(); 085 } else { 086 // handle div-by-zero 087 if (cm.getDomain().size() == 0) { 088 logger.warning("Empty domain: accuracy ill-defined"); 089 return Double.NaN; 090 } 091 double total = 0d; 092 for (T output : cm.getDomain().getDomain()) { 093 total += accuracy(output, cm); 094 } 095 return total / cm.getDomain().size(); 096 } 097 } 098 099 /** 100 * Calculates the balanced error rate, i.e., the mean of the recalls. 101 * 102 * @param <T> the type parameter 103 * @param cm The confusion matrix 104 * @return the balanced error rate. 105 */ 106 public static <T extends Classifiable<T>> double balancedErrorRate(ConfusionMatrix<T> cm) { 107 // handle div-by-zero 108 if (cm.getDomain().size() == 0) { 109 logger.warning("Empty domain: balanced error rate ill-defined"); 110 return Double.NaN; 111 } 112 double sr = 0d; 113 for (T output : cm.getDomain().getDomain()) { 114 sr += recall(new MetricTarget<>(output), cm); 115 } 116 return 1d - (sr / cm.getDomain().size()); 117 } 118 119 /** 120 * Computes the confusion function value for a given metric target and confusion matrix. 121 * <p> 122 * For example - to compute macro precision: 123 * 124 * <code> 125 * ConfusionFunction<T> fxn = ConfusionMetric::precision; 126 * MetricTarget<T> tgt = new MetricTarget(Average.macro) 127 * ConfusionMatrix<T> cm = ... 128 * compute(fxn, tgt, cm); 129 * </code> 130 * <p> 131 * This is equivalent to the following: 132 * 133 * <code> 134 * ConfusionMatrix<T> cm = ... 135 * double total = 0d; 136 * for (T label : cm.getDomain().getDomain()) { 137 * total += precision(cm.tp(label), cm.tp(label), ...); 138 * } 139 * double avg = total / cm.getDomain().size() 140 * </code> 141 * 142 * @param fxn the confusion function 143 * @param tgt the metric target 144 * @param cm the confusion matrix 145 * @param <T> the output type 146 * @return the value of fxn applied to (tgt, cm) 147 */ 148 private static <T extends Classifiable<T>> double compute(ConfusionFunction<T> fxn, MetricTarget<T> tgt, ConfusionMatrix<T> cm) { 149 return fxn.compute(tgt, cm); 150 } 151 152 /** 153 * Returns the number of true positives, possibly averaged depending on the metric target. 154 * 155 * @param <T> the type parameter 156 * @param tgt The metric target 157 * @param cm The confusion matrix 158 * @return the true positives. 159 */ 160 public static <T extends Classifiable<T>> double tp(MetricTarget<T> tgt, ConfusionMatrix<T> cm) { 161 return compute(ConfusionMetrics::tp, tgt, cm); 162 } 163 164 /** 165 * Returns the number of false positives, possibly averaged depending on the metric target. 166 * 167 * @param <T> the type parameter 168 * @param tgt The metric target 169 * @param cm The confusion matrix 170 * @return the false positives. 171 */ 172 public static <T extends Classifiable<T>> double fp(MetricTarget<T> tgt, ConfusionMatrix<T> cm) { 173 return compute(ConfusionMetrics::fp, tgt, cm); 174 } 175 176 /** 177 * Returns the number of true negatives, possibly averaged depending on the metric target. 178 * 179 * @param <T> the type parameter 180 * @param tgt The metric target 181 * @param cm The confusion matrix 182 * @return the true negatives. 183 */ 184 public static <T extends Classifiable<T>> double tn(MetricTarget<T> tgt, ConfusionMatrix<T> cm) { 185 return compute(ConfusionMetrics::tn, tgt, cm); 186 } 187 188 /** 189 * Returns the number of false negatives, possibly averaged depending on the metric target. 190 * 191 * @param <T> the type parameter 192 * @param tgt The metric target 193 * @param cm The confusion matrix 194 * @return the false negatives. 195 */ 196 public static <T extends Classifiable<T>> double fn(MetricTarget<T> tgt, ConfusionMatrix<T> cm) { 197 return compute(ConfusionMetrics::fn, tgt, cm); 198 } 199 200 /** 201 * Helper function to return the specified argument. Used as a method reference. 202 * @param tp The true positives. 203 * @param fp The false positives. 204 * @param tn The true negatives. 205 * @param fn The false negatives. 206 * @return The true positives. 207 */ 208 private static double tp(double tp, double fp, double tn, double fn) { 209 return tp; 210 } 211 212 /** 213 * Helper function to return the specified argument. Used as a method reference. 214 * @param tp The true positives. 215 * @param fp The false positives. 216 * @param tn The true negatives. 217 * @param fn The false negatives. 218 * @return The false positives. 219 */ 220 private static double fp(double tp, double fp, double tn, double fn) { 221 return fp; 222 } 223 224 /** 225 * Helper function to return the specified argument. Used as a method reference. 226 * @param tp The true positives. 227 * @param fp The false positives. 228 * @param tn The true negatives. 229 * @param fn The false negatives. 230 * @return The true negatives. 231 */ 232 private static double tn(double tp, double fp, double tn, double fn) { 233 return tn; 234 } 235 236 /** 237 * Helper function to return the specified argument. Used as a method reference. 238 * @param tp The true positives. 239 * @param fp The false positives. 240 * @param tn The true negatives. 241 * @param fn The false negatives. 242 * @return The false negatives. 243 */ 244 private static double fn(double tp, double fp, double tn, double fn) { 245 return fn; 246 } 247 248 // 249 // PRECISION --------------------------------------------------------------- 250 // 251 252 /** 253 * Calculates the precision for this metric target. 254 * 255 * @param <T> the type parameter 256 * @param tgt The metric target 257 * @param cm The confusion matrix 258 * @return the precision. 259 */ 260 public static <T extends Classifiable<T>> double precision(MetricTarget<T> tgt, ConfusionMatrix<T> cm) { 261 return compute(ConfusionMetrics::precision, tgt, cm); 262 } 263 264 /** 265 * Calculates the precision based upon the supplied statistics. 266 * 267 * @param tp the true positives 268 * @param fp the false positives 269 * @param tn the true negatives 270 * @param fn the false negatives 271 * @return The recall. 272 */ 273 public static double precision(double tp, double fp, double tn, double fn) { 274 double denom = tp + fp; 275 // If the denominator is 0, return 0 (as opposed to Double.NaN, say) 276 return (denom == 0) ? 0d : tp / denom; 277 } 278 279 // 280 // RECALL ------------------------------------------------------------------ 281 // 282 283 /** 284 * Calculates the recall for this metric target. 285 * 286 * @param <T> the type parameter 287 * @param tgt The metric target 288 * @param cm The confusion matrix 289 * @return The recall. 290 */ 291 public static <T extends Classifiable<T>> double recall(MetricTarget<T> tgt, ConfusionMatrix<T> cm) { 292 return compute(ConfusionMetrics::recall, tgt, cm); 293 } 294 295 /** 296 * Calculates the recall based upon the supplied statistics. 297 * 298 * @param tp the true positives 299 * @param fp the false positives 300 * @param tn the true negatives 301 * @param fn the false negatives 302 * @return The recall. 303 */ 304 public static double recall(double tp, double fp, double tn, double fn) { 305 double denom = tp + fn; 306 // If the denominator is 0, return 0 (as opposed to Double.NaN, say) 307 return (denom == 0) ? 0d : tp / denom; 308 } 309 310 // 311 // F-SCORE ----------------------------------------------------------------- 312 // 313 314 /** 315 * Computes the F_1 score. 316 * 317 * @param <T> the type parameter 318 * @param tgt the metric target. 319 * @param cm the confusion matrix. 320 * @return the F_1 score. 321 */ 322 public static <T extends Classifiable<T>> double f1(MetricTarget<T> tgt, ConfusionMatrix<T> cm) { 323 return compute(ConfusionMetrics::f1, tgt, cm); 324 } 325 326 /** 327 * Computes the F_1 score. 328 * 329 * @param tp the true positives 330 * @param fp the false positives 331 * @param tn the true negatives 332 * @param fn the false negatives 333 * @return the F_1 score. 334 */ 335 public static double f1(double tp, double fp, double tn, double fn) { 336 return fscore(1d, tp, fp, tn, fn); 337 } 338 339 /** 340 * Computes the Fscore. 341 * 342 * @param beta the beta. 343 * @param tp the true positives. 344 * @param fp the false positives. 345 * @param tn the true negatives. 346 * @param fn the false negatives. 347 * @return the F_beta score. 348 */ 349 public static double fscore(double beta, double tp, double fp, double tn, double fn) { 350 double bsq = beta * beta; 351 double p = precision(tp, fp, tn, fn); 352 double r = recall(tp, fp, tn, fn); 353 double denom = (bsq * p) + r; 354 return (denom == 0) ? 0d : (1 + bsq) * p * r / denom; 355 } 356 357 /** 358 * Computes the Fscore. 359 * 360 * @param <T> the type parameter 361 * @param tgt The metric target 362 * @param cm The confusion matrix 363 * @param beta the beta 364 * @return The F_beta score. 365 */ 366 public static <T extends Classifiable<T>> double fscore(MetricTarget<T> tgt, ConfusionMatrix<T> cm, double beta) { 367 ConfusionFunction<T> fxn = (tp, fp, tn, fn) -> fscore(beta, tp, fp, tn, fn); 368 return compute(fxn, tgt, cm); 369 } 370 371 /** 372 * A function that takes a {@link MetricTarget} and {@link ConfusionMatrix} as inputs and outputs the value of 373 * the confusion metric specified in the implementation of 374 * {@link ConfusionFunction#compute(double, double, double, double)}. 375 * 376 * @param <T> The classification type. 377 */ 378 @FunctionalInterface 379 private static interface ConfusionFunction<T extends Classifiable<T>> { 380 381 /** 382 * Provides a uniform function signature for a bunch of different metrics. 383 * 384 * @param tp the true positives. 385 * @param fp the false positives. 386 * @param tn the true negatives. 387 * @param fn the false negatives. 388 * @return the value. 389 */ 390 double compute(double tp, double fp, double tn, double fn); 391 392 /** 393 * Compute the value. 394 * 395 * @param tgt the metric target. 396 * @param cm the confusion matrix. 397 * @return the value. 398 */ 399 default double compute(MetricTarget<T> tgt, ConfusionMatrix<T> cm) { 400 if (tgt.getOutputTarget().isPresent()) { 401 return compute(tgt.getOutputTarget().get(), cm); 402 } else if (tgt.getAverageTarget().isPresent()) { 403 return compute(tgt.getAverageTarget().get(), cm); 404 } else { 405 throw new IllegalStateException("MetricTarget with no actual target"); 406 } 407 } 408 409 /** 410 * Compute the value. 411 * 412 * @param label the target label. 413 * @param cm the confusion matrix. 414 * @return the value. 415 */ 416 default double compute(T label, ConfusionMatrix<T> cm) { 417 return compute(cm.tp(label), cm.fp(label), cm.tn(label), cm.fn(label)); 418 } 419 420 /** 421 * Compute the value. 422 * 423 * @param average the average type. 424 * @param cm the confusion matrix. 425 * @return the value. 426 */ 427 default double compute(Average average, ConfusionMatrix<T> cm) { 428 switch (average) { 429 case MACRO: 430 if (cm.getDomain().size() == 0) { 431 logger.warning("Empty domain: macro-average ill-defined."); 432 return Double.NaN; 433 } 434 double total = 0d; 435 for (T output : cm.getDomain().getDomain()) { 436 total += compute(output, cm); 437 } 438 return total / cm.getDomain().size(); 439 case MICRO: 440 if (cm.support() == 0) { 441 logger.warning("No predictions: micro-average ill-defined."); 442 return Double.NaN; 443 } 444 return compute(cm.tp(), cm.fp(), cm.tn(), cm.fn()); 445 default: 446 throw new IllegalArgumentException("Unsupported average type: " + average.name()); 447 } 448 } 449 } 450 451}