001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.evaluation;
018
019import org.tribuo.classification.Classifiable;
020import org.tribuo.evaluation.metrics.EvaluationMetric.Average;
021import org.tribuo.evaluation.metrics.MetricTarget;
022
023import java.util.logging.Logger;
024
025/**
026 * Static functions for computing classification metrics based on a {@link ConfusionMatrix}.
027 */
028public final class ConfusionMetrics {
029
030    private static final Logger logger = Logger.getLogger(ConfusionMetrics.class.getName());
031
032    // singleton
033    private ConfusionMetrics() { }
034
035    /**
036     * Calculates the accuracy given this confusion matrix.
037     *
038     * @param <T>    The type parameter
039     * @param target The metric target
040     * @param cm     The confusion matrix
041     * @return The accuracy
042     */
043    public static <T extends Classifiable<T>> double accuracy(MetricTarget<T> target, ConfusionMatrix<T> cm) {
044        if (target.getOutputTarget().isPresent()) {
045            return accuracy(target.getOutputTarget().get(), cm);
046        } else {
047            return accuracy(target.getAverageTarget().get(), cm);
048        }
049    }
050
051    /**
052     * Calculates a per label accuracy given this confusion matrix.
053     *
054     * @param <T>   The type parameter
055     * @param label The label
056     * @param cm    The confusion matrix
057     * @return The accuracy
058     */
059    public static <T extends Classifiable<T>> double accuracy(T label, ConfusionMatrix<T> cm) {
060        double support = cm.support(label);
061        // handle div-by-zero
062        if (support == 0d) {
063            logger.warning("No predictions: accuracy ill-defined");
064            return Double.NaN;
065        }
066        return cm.tp(label) / cm.support(label);
067    }
068
069    /**
070     * Calculates the accuracy using the specified average type and confusion matrix.
071     *
072     * @param <T>     the type parameter
073     * @param average the average
074     * @param cm      The confusion matrix
075     * @return The accuracy
076     */
077    public static <T extends Classifiable<T>> double accuracy(Average average, ConfusionMatrix<T> cm) {
078        if (average.equals(Average.MICRO)) {
079            // handle div-by-zero
080            if (cm.support() == 0d) {
081                logger.warning("No predictions: accuracy ill-defined");
082                return Double.NaN;
083            }
084            return cm.tp() / cm.support();
085        } else {
086            // handle div-by-zero
087            if (cm.getDomain().size() == 0) {
088                logger.warning("Empty domain: accuracy ill-defined");
089                return Double.NaN;
090            }
091            double total = 0d;
092            for (T output : cm.getDomain().getDomain()) {
093                total += accuracy(output, cm);
094            }
095            return total / cm.getDomain().size();
096        }
097    }
098
099    /**
100     * Calculates the balanced error rate, i.e., the mean of the recalls.
101     *
102     * @param <T> the type parameter
103     * @param cm  The confusion matrix
104     * @return the balanced error rate.
105     */
106    public static <T extends Classifiable<T>> double balancedErrorRate(ConfusionMatrix<T> cm) {
107        // handle div-by-zero
108        if (cm.getDomain().size() == 0) {
109            logger.warning("Empty domain: balanced error rate ill-defined");
110            return Double.NaN;
111        }
112        double sr = 0d;
113        for (T output : cm.getDomain().getDomain()) {
114            sr += recall(new MetricTarget<>(output), cm);
115        }
116        return 1d - (sr / cm.getDomain().size());
117    }
118
119    /**
120     * Computes the confusion function value for a given metric target and confusion matrix.
121     * <p>
122     * For example - to compute macro precision:
123     *
124     * <code>
125     * ConfusionFunction&lt;T&gt; fxn = ConfusionMetric::precision;
126     * MetricTarget&lt;T&gt; tgt = new MetricTarget(Average.macro)
127     * ConfusionMatrix&lt;T&gt; cm = ...
128     * compute(fxn, tgt, cm);
129     * </code>
130     * <p>
131     * This is equivalent to the following:
132     *
133     * <code>
134     * ConfusionMatrix&lt;T&gt; cm = ...
135     * double total = 0d;
136     * for (T label : cm.getDomain().getDomain()) {
137     * total += precision(cm.tp(label), cm.tp(label), ...);
138     * }
139     * double avg = total / cm.getDomain().size()
140     * </code>
141     *
142     * @param fxn the confusion function
143     * @param tgt the metric target
144     * @param cm  the confusion matrix
145     * @param <T> the output type
146     * @return the value of fxn applied to (tgt, cm)
147     */
148    private static <T extends Classifiable<T>> double compute(ConfusionFunction<T> fxn, MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
149        return fxn.compute(tgt, cm);
150    }
151
152    /**
153     * Returns the number of true positives, possibly averaged depending on the metric target.
154     *
155     * @param <T> the type parameter
156     * @param tgt The metric target
157     * @param cm  The confusion matrix
158     * @return the true positives.
159     */
160    public static <T extends Classifiable<T>> double tp(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
161        return compute(ConfusionMetrics::tp, tgt, cm);
162    }
163
164    /**
165     * Returns the number of false positives, possibly averaged depending on the metric target.
166     *
167     * @param <T> the type parameter
168     * @param tgt The metric target
169     * @param cm  The confusion matrix
170     * @return the false positives.
171     */
172    public static <T extends Classifiable<T>> double fp(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
173        return compute(ConfusionMetrics::fp, tgt, cm);
174    }
175
176    /**
177     * Returns the number of true negatives, possibly averaged depending on the metric target.
178     *
179     * @param <T> the type parameter
180     * @param tgt The metric target
181     * @param cm  The confusion matrix
182     * @return the true negatives.
183     */
184    public static <T extends Classifiable<T>> double tn(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
185        return compute(ConfusionMetrics::tn, tgt, cm);
186    }
187
188    /**
189     * Returns the number of false negatives, possibly averaged depending on the metric target.
190     *
191     * @param <T> the type parameter
192     * @param tgt The metric target
193     * @param cm  The confusion matrix
194     * @return the false negatives.
195     */
196    public static <T extends Classifiable<T>> double fn(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
197        return compute(ConfusionMetrics::fn, tgt, cm);
198    }
199
200    /**
201     * Helper function to return the specified argument. Used as a method reference.
202     * @param tp The true positives.
203     * @param fp The false positives.
204     * @param tn The true negatives.
205     * @param fn The false negatives.
206     * @return The true positives.
207     */
208    private static double tp(double tp, double fp, double tn, double fn) {
209        return tp;
210    }
211
212    /**
213     * Helper function to return the specified argument. Used as a method reference.
214     * @param tp The true positives.
215     * @param fp The false positives.
216     * @param tn The true negatives.
217     * @param fn The false negatives.
218     * @return The false positives.
219     */
220    private static double fp(double tp, double fp, double tn, double fn) {
221        return fp;
222    }
223
224    /**
225     * Helper function to return the specified argument. Used as a method reference.
226     * @param tp The true positives.
227     * @param fp The false positives.
228     * @param tn The true negatives.
229     * @param fn The false negatives.
230     * @return The true negatives.
231     */
232    private static double tn(double tp, double fp, double tn, double fn) {
233        return tn;
234    }
235
236    /**
237     * Helper function to return the specified argument. Used as a method reference.
238     * @param tp The true positives.
239     * @param fp The false positives.
240     * @param tn The true negatives.
241     * @param fn The false negatives.
242     * @return The false negatives.
243     */
244    private static double fn(double tp, double fp, double tn, double fn) {
245        return fn;
246    }
247
248    //
249    // PRECISION ---------------------------------------------------------------
250    //
251
252    /**
253     * Calculates the precision for this metric target.
254     *
255     * @param <T> the type parameter
256     * @param tgt The metric target
257     * @param cm  The confusion matrix
258     * @return the precision.
259     */
260    public static <T extends Classifiable<T>> double precision(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
261        return compute(ConfusionMetrics::precision, tgt, cm);
262    }
263
264    /**
265     * Calculates the precision based upon the supplied statistics.
266     *
267     * @param tp  the true positives
268     * @param fp  the false positives
269     * @param tn  the true negatives
270     * @param fn  the false negatives
271     * @return The recall.
272     */
273    public static double precision(double tp, double fp, double tn, double fn) {
274        double denom = tp + fp;
275        // If the denominator is 0, return 0 (as opposed to Double.NaN, say)
276        return (denom == 0) ? 0d : tp / denom;
277    }
278
279    //
280    // RECALL ------------------------------------------------------------------
281    //
282
283    /**
284     * Calculates the recall for this metric target.
285     *
286     * @param <T> the type parameter
287     * @param tgt The metric target
288     * @param cm  The confusion matrix
289     * @return The recall.
290     */
291    public static <T extends Classifiable<T>> double recall(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
292        return compute(ConfusionMetrics::recall, tgt, cm);
293    }
294
295    /**
296     * Calculates the recall based upon the supplied statistics.
297     *
298     * @param tp  the true positives
299     * @param fp  the false positives
300     * @param tn  the true negatives
301     * @param fn  the false negatives
302     * @return The recall.
303     */
304    public static double recall(double tp, double fp, double tn, double fn) {
305        double denom = tp + fn;
306        // If the denominator is 0, return 0 (as opposed to Double.NaN, say)
307        return (denom == 0) ? 0d : tp / denom;
308    }
309
310    //
311    // F-SCORE -----------------------------------------------------------------
312    //
313
314    /**
315     * Computes the F_1 score.
316     *
317     * @param <T> the type parameter
318     * @param tgt the metric target.
319     * @param cm  the confusion matrix.
320     * @return the F_1 score.
321     */
322    public static <T extends Classifiable<T>> double f1(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
323        return compute(ConfusionMetrics::f1, tgt, cm);
324    }
325
326    /**
327     * Computes the F_1 score.
328     *
329     * @param tp  the true positives
330     * @param fp  the false positives
331     * @param tn  the true negatives
332     * @param fn  the false negatives
333     * @return the F_1 score.
334     */
335    public static double f1(double tp, double fp, double tn, double fn) {
336        return fscore(1d, tp, fp, tn, fn);
337    }
338
339    /**
340     * Computes the Fscore.
341     *
342     * @param beta the beta.
343     * @param tp   the true positives.
344     * @param fp   the false positives.
345     * @param tn   the true negatives.
346     * @param fn   the false negatives.
347     * @return the F_beta score.
348     */
349    public static double fscore(double beta, double tp, double fp, double tn, double fn) {
350        double bsq = beta * beta;
351        double p = precision(tp, fp, tn, fn);
352        double r = recall(tp, fp, tn, fn);
353        double denom = (bsq * p) + r;
354        return (denom == 0) ? 0d : (1 + bsq) * p * r / denom;
355    }
356
357    /**
358     * Computes the Fscore.
359     *
360     * @param <T>  the type parameter
361     * @param tgt  The metric target
362     * @param cm   The confusion matrix
363     * @param beta the beta
364     * @return The F_beta score.
365     */
366    public static <T extends Classifiable<T>> double fscore(MetricTarget<T> tgt, ConfusionMatrix<T> cm, double beta) {
367        ConfusionFunction<T> fxn = (tp, fp, tn, fn) -> fscore(beta, tp, fp, tn, fn);
368        return compute(fxn, tgt, cm);
369    }
370
371    /**
372     * A function that takes a {@link MetricTarget} and {@link ConfusionMatrix} as inputs and outputs the value of
373     * the confusion metric specified in the implementation of
374     * {@link ConfusionFunction#compute(double, double, double, double)}.
375     *
376     * @param <T> The classification type.
377     */
378    @FunctionalInterface
379    private static interface ConfusionFunction<T extends Classifiable<T>> {
380
381        /**
382         * Provides a uniform function signature for a bunch of different metrics.
383         *
384         * @param tp the true positives.
385         * @param fp the false positives.
386         * @param tn the true negatives.
387         * @param fn the false negatives.
388         * @return the value.
389         */
390        double compute(double tp, double fp, double tn, double fn);
391
392        /**
393         * Compute the value.
394         *
395         * @param tgt the metric target.
396         * @param cm  the confusion matrix.
397         * @return the value.
398         */
399        default double compute(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
400            if (tgt.getOutputTarget().isPresent()) {
401                return compute(tgt.getOutputTarget().get(), cm);
402            } else if (tgt.getAverageTarget().isPresent()) {
403                return compute(tgt.getAverageTarget().get(), cm);
404            } else {
405                throw new IllegalStateException("MetricTarget with no actual target");
406            }
407        }
408
409        /**
410         * Compute the value.
411         *
412         * @param label the target label.
413         * @param cm    the confusion matrix.
414         * @return the value.
415         */
416        default double compute(T label, ConfusionMatrix<T> cm) {
417            return compute(cm.tp(label), cm.fp(label), cm.tn(label), cm.fn(label));
418        }
419
420        /**
421         * Compute the value.
422         *
423         * @param average the average type.
424         * @param cm      the confusion matrix.
425         * @return the value.
426         */
427        default double compute(Average average, ConfusionMatrix<T> cm) {
428            switch (average) {
429                case MACRO:
430                    if (cm.getDomain().size() == 0) {
431                        logger.warning("Empty domain: macro-average ill-defined.");
432                        return Double.NaN;
433                    }
434                    double total = 0d;
435                    for (T output : cm.getDomain().getDomain()) {
436                        total += compute(output, cm);
437                    }
438                    return total / cm.getDomain().size();
439                case MICRO:
440                    if (cm.support() == 0) {
441                        logger.warning("No predictions: micro-average ill-defined.");
442                        return Double.NaN;
443                    }
444                    return compute(cm.tp(), cm.fp(), cm.tn(), cm.fn());
445                default:
446                    throw new IllegalArgumentException("Unsupported average type: " + average.name());
447            }
448        }
449    }
450
451}