001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.evaluation;
018
019import org.tribuo.ImmutableOutputInfo;
020import org.tribuo.classification.Classifiable;
021
022import java.util.function.ToDoubleFunction;
023
024/**
025 * A confusion matrix for {@link Classifiable}s.
026 *
027 * <p>
028 * We interpret it as follows:
029 *
030 * {@code
031 * C[i, j] = k
032 * }
033 *
034 * means "the TRUE class 'j' was PREDICTED to be class 'i' a total of 'k' times".
035 *
036 * <p>
037 * In other words, the row indices correspond to the model's predictions, and the column indices correspond to
038 * the ground truth.
039 * </p>
040 * @param <T> The type of the output.
041 */
042public interface ConfusionMatrix<T extends Classifiable<T>> {
043
044    /**
045     * Returns the classification domain that this confusion matrix operates over.
046     * @return The classification domain.
047     */
048    public ImmutableOutputInfo<T> getDomain();
049
050    /**
051     * The number of examples this confusion matrix has seen.
052     * @return The number of examples.
053     */
054    public double support();
055
056    /**
057     * The number of examples with this true label this confusion matrix has seen.
058     * @param cls The label.
059     * @return The number of examples.
060     */
061    public double support(T cls);
062
063    /**
064     * The number of true positives for the supplied label.
065     * @param cls The label.
066     * @return The number of examples.
067     */
068    public double tp(T cls);
069
070    /**
071     * The number of false positives for the supplied label.
072     * @param cls The label.
073     * @return The number of examples.
074     */
075    public double fp(T cls);
076
077    /**
078     * The number of false negatives for the supplied label.
079     * @param cls The label.
080     * @return The number of examples.
081     */
082    public double fn(T cls);
083
084    /**
085     * The number of true negatives for the supplied label.
086     * @param cls The label.
087     * @return The number of examples.
088     */
089    public double tn(T cls);
090
091    /**
092     * The number of times the supplied predicted label was returned for the supplied true class.
093     * @param predictedLabel The predicted label.
094     * @param trueLabel The true label.
095     * @return The number of examples predicted as {@code predictedLabel} when the true label was {@code trueLabel}.
096     */
097    public double confusion(T predictedLabel, T trueLabel);
098
099    /**
100     * The total number of true positives.
101     * @return The total true positives.
102     */
103    public default double tp() {
104        return sumOverOutputs(getDomain(), this::tp);
105    }
106
107    /**
108     * The total number of false positives.
109     * @return The total false positives.
110     */
111    public default double fp() {
112        return sumOverOutputs(getDomain(), this::fp);
113    }
114
115    /**
116     * The total number of false negatives.
117     * @return The total false negatives.
118     */
119    public default double fn() {
120        return sumOverOutputs(getDomain(), this::fn);
121    }
122
123    /**
124     * The total number of true negatives.
125     * @return The total true negatives.
126     */
127    public default double tn() {
128        return sumOverOutputs(getDomain(), this::tn);
129    }
130
131    /**
132     * Sums the supplied getter over the domain.
133     * @param domain The domain to sum over.
134     * @param getter The getter to use.
135     * @param <T> The type of the output.
136     * @return The total summed over the domain.
137     */
138    static <T extends Classifiable<T>> double sumOverOutputs(ImmutableOutputInfo<T> domain, ToDoubleFunction<T> getter) {
139        double total = 0;
140        for (T key : domain.getDomain()) {
141            total += getter.applyAsDouble(key);
142        }
143        return total;
144    }
145}