001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.evaluation; 018 019import org.tribuo.ImmutableOutputInfo; 020import org.tribuo.classification.Classifiable; 021 022import java.util.function.ToDoubleFunction; 023 024/** 025 * A confusion matrix for {@link Classifiable}s. 026 * 027 * <p> 028 * We interpret it as follows: 029 * 030 * {@code 031 * C[i, j] = k 032 * } 033 * 034 * means "the TRUE class 'j' was PREDICTED to be class 'i' a total of 'k' times". 035 * 036 * <p> 037 * In other words, the row indices correspond to the model's predictions, and the column indices correspond to 038 * the ground truth. 039 * </p> 040 * @param <T> The type of the output. 041 */ 042public interface ConfusionMatrix<T extends Classifiable<T>> { 043 044 /** 045 * Returns the classification domain that this confusion matrix operates over. 046 * @return The classification domain. 047 */ 048 public ImmutableOutputInfo<T> getDomain(); 049 050 /** 051 * The number of examples this confusion matrix has seen. 052 * @return The number of examples. 053 */ 054 public double support(); 055 056 /** 057 * The number of examples with this true label this confusion matrix has seen. 058 * @param cls The label. 059 * @return The number of examples. 060 */ 061 public double support(T cls); 062 063 /** 064 * The number of true positives for the supplied label. 065 * @param cls The label. 066 * @return The number of examples. 067 */ 068 public double tp(T cls); 069 070 /** 071 * The number of false positives for the supplied label. 072 * @param cls The label. 073 * @return The number of examples. 074 */ 075 public double fp(T cls); 076 077 /** 078 * The number of false negatives for the supplied label. 079 * @param cls The label. 080 * @return The number of examples. 081 */ 082 public double fn(T cls); 083 084 /** 085 * The number of true negatives for the supplied label. 086 * @param cls The label. 087 * @return The number of examples. 088 */ 089 public double tn(T cls); 090 091 /** 092 * The number of times the supplied predicted label was returned for the supplied true class. 093 * @param predictedLabel The predicted label. 094 * @param trueLabel The true label. 095 * @return The number of examples predicted as {@code predictedLabel} when the true label was {@code trueLabel}. 096 */ 097 public double confusion(T predictedLabel, T trueLabel); 098 099 /** 100 * The total number of true positives. 101 * @return The total true positives. 102 */ 103 public default double tp() { 104 return sumOverOutputs(getDomain(), this::tp); 105 } 106 107 /** 108 * The total number of false positives. 109 * @return The total false positives. 110 */ 111 public default double fp() { 112 return sumOverOutputs(getDomain(), this::fp); 113 } 114 115 /** 116 * The total number of false negatives. 117 * @return The total false negatives. 118 */ 119 public default double fn() { 120 return sumOverOutputs(getDomain(), this::fn); 121 } 122 123 /** 124 * The total number of true negatives. 125 * @return The total true negatives. 126 */ 127 public default double tn() { 128 return sumOverOutputs(getDomain(), this::tn); 129 } 130 131 /** 132 * Sums the supplied getter over the domain. 133 * @param domain The domain to sum over. 134 * @param getter The getter to use. 135 * @param <T> The type of the output. 136 * @return The total summed over the domain. 137 */ 138 static <T extends Classifiable<T>> double sumOverOutputs(ImmutableOutputInfo<T> domain, ToDoubleFunction<T> getter) { 139 double total = 0; 140 for (T key : domain.getDomain()) { 141 total += getter.applyAsDouble(key); 142 } 143 return total; 144 } 145}