/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.evaluation;

import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelMetric.Context;
import org.tribuo.evaluation.metrics.MetricTarget;

import java.util.List;
import java.util.function.ToDoubleBiFunction;

/**
 * An enum of the default {@link LabelMetric}s supported by the multi-class classification
 * evaluation package.
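 * <p>
 * A minimal usage sketch (the names {@code model}, {@code testData} and the label string
 * {@code "SPAM"} are illustrative assumptions, not part of this class):
 * <pre>{@code
 * List<Prediction<Label>> predictions = model.predict(testData);
 * double auc = LabelMetrics.AUCROC(new Label("SPAM"), predictions);
 * double averagedPrecision = LabelMetrics.averagedPrecision(new Label("SPAM"), predictions);
 * }</pre>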
 */
public enum LabelMetrics {

    /**
     * The number of true positives.
     */
    TP((tgt, ctx) -> ConfusionMetrics.tp(tgt, ctx.getCM())),
    /**
     * The number of false positives.
     */
    FP((tgt, ctx) -> ConfusionMetrics.fp(tgt, ctx.getCM())),
    /**
     * The number of true negatives.
     */
    TN((tgt, ctx) -> ConfusionMetrics.tn(tgt, ctx.getCM())),
    /**
     * The number of false negatives.
     */
    FN((tgt, ctx) -> ConfusionMetrics.fn(tgt, ctx.getCM())),
    /**
     * The precision, i.e., the number of true positives divided by the number of predicted positives.
     */
    PRECISION((tgt, ctx) -> ConfusionMetrics.precision(tgt, ctx.getCM())),
    /**
     * The recall, i.e., the number of true positives divided by the number of ground truth positives.
     */
    RECALL((tgt, ctx) -> ConfusionMetrics.recall(tgt, ctx.getCM())),
    /**
     * The F_1 score, i.e., the harmonic mean of the precision and the recall.
     */
    F1((tgt, ctx) -> ConfusionMetrics.f1(tgt, ctx.getCM())),
    /**
     * The accuracy.
     */
    ACCURACY((tgt, ctx) -> ConfusionMetrics.accuracy(tgt, ctx.getCM())),
    /**
     * The balanced error rate, i.e., the mean of the per class error rates (one minus the mean of the per class recalls).
     */
    BALANCED_ERROR_RATE((tgt, ctx) -> ConfusionMetrics.balancedErrorRate(ctx.getCM())),
    /**
     * The area under the receiver operating characteristic (ROC) curve.
     */
    AUCROC((tgt, ctx) -> LabelMetrics.AUCROC(tgt, ctx.getPredictions())),
    /**
     * The averaged precision.
     */
    AVERAGED_PRECISION((tgt, ctx) -> LabelMetrics.averagedPrecision(tgt, ctx.getPredictions()));

    private final ToDoubleBiFunction<MetricTarget<Label>, LabelMetric.Context> impl;

    LabelMetrics(ToDoubleBiFunction<MetricTarget<Label>, LabelMetric.Context> impl) {
        this.impl = impl;
    }

    /**
     * Returns the implementing function for this metric.
     * @return The implementing function.
     */
    public ToDoubleBiFunction<MetricTarget<Label>, Context> getImpl() {
        return impl;
    }

    /**
     * Gets the LabelMetric wrapped around the supplied MetricTarget.
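     * <p>
     * A brief sketch of the two target flavours (the {@code MetricTarget.macroAverageTarget()}
     * factory is assumed from the wider evaluation package):
     * <pre>{@code
     * LabelMetric spamF1 = LabelMetrics.F1.forTarget(new MetricTarget<>(new Label("SPAM")));
     * LabelMetric macroF1 = LabelMetrics.F1.forTarget(MetricTarget.macroAverageTarget());
     * }</pre>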
     * @param tgt The metric target.
     * @return The label metric combining the implementation function with the supplied metric target.
     */
    public LabelMetric forTarget(MetricTarget<Label> tgt) {
        return new LabelMetric(tgt, this.name(), this.getImpl());
    }

    /**
     * @see LabelEvaluationUtil#averagedPrecision(boolean[], double[])
     * @param tgt The metric target to use.
     * @param predictions The predictions to use.
     * @return The averaged precision for the supplied target with the supplied predictions.
     */
    public static double averagedPrecision(MetricTarget<Label> tgt, List<Prediction<Label>> predictions) {
        if (tgt.getOutputTarget().isPresent()) {
            return averagedPrecision(tgt.getOutputTarget().get(), predictions);
        } else {
            throw new IllegalStateException("Unsupported MetricTarget for averagedPrecision");
        }
    }

    /**
     * @see LabelEvaluationUtil#averagedPrecision(boolean[], double[])
     * @param label The Label to average across.
     * @param predictions The predictions to use.
     * @return The averaged precision for the supplied label with the supplied predictions.
     */
    public static double averagedPrecision(Label label, List<Prediction<Label>> predictions) {
        PredictionProbabilities record = new PredictionProbabilities(label, predictions);
        return LabelEvaluationUtil.averagedPrecision(record.ypos, record.yscore);
    }

    /**
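     * Computes the precision-recall curve for a single label.
     * <p>
     * A hedged sketch of reading the result (the {@code precision} and {@code recall} field names
     * on {@link LabelEvaluationUtil.PRCurve} are assumed here):
     * <pre>{@code
     * LabelEvaluationUtil.PRCurve curve = LabelMetrics.precisionRecallCurve(new Label("SPAM"), predictions);
     * // curve.precision and curve.recall trace the curve point by point.
     * }</pre>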
     * @see LabelEvaluationUtil#generatePRCurve(boolean[], double[])
     * @param label The Label to calculate precision and recall for.
     * @param predictions The predictions to use.
     * @return The Precision Recall Curve for the supplied label with the supplied predictions.
     */
    public static LabelEvaluationUtil.PRCurve precisionRecallCurve(Label label, List<Prediction<Label>> predictions) {
        PredictionProbabilities record = new PredictionProbabilities(label, predictions);
        return LabelEvaluationUtil.generatePRCurve(record.ypos, record.yscore);
    }

    /**
     * Area under the ROC curve.
     *
     * @param label the label corresponding to the "positive" class
     * @param predictions the predictions for which we'll compute the score
     * @return AUC ROC score
     * @throws UnsupportedOperationException if a prediction has no probability score; probability scores are required to compute the ROC curve. (See also: {@link Model#generatesProbabilities()})
     */
    public static double AUCROC(Label label, List<Prediction<Label>> predictions) {
        PredictionProbabilities record = new PredictionProbabilities(label, predictions);
        return LabelEvaluationUtil.binaryAUCROC(record.ypos, record.yscore);
    }

    /**
     * Area under the ROC curve.
     *
     * @param tgt The metric target for the positive class.
     * @param predictions the predictions for which we'll compute the score
     * @return AUC ROC score
     * @throws UnsupportedOperationException if a prediction has no probability score; probability scores are required to compute the ROC curve. (See also: {@link Model#generatesProbabilities()})
     */
    public static double AUCROC(MetricTarget<Label> tgt, List<Prediction<Label>> predictions) {
        if (tgt.getOutputTarget().isPresent()) {
            return AUCROC(tgt.getOutputTarget().get(), predictions);
        } else {
            throw new IllegalStateException("Unsupported MetricTarget for AUCROC");
        }
    }

    /**
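     * Converts a list of predictions into parallel arrays: a boolean positive indicator and the
     * score the model assigned to the target label for each prediction.
     * <p>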
     * One day, it'll be a record. Not today, mind.
     */
    private static final class PredictionProbabilities {
        final boolean[] ypos;
        final double[] yscore;
        PredictionProbabilities(Label label, List<Prediction<Label>> predictions) {
            int n = predictions.size();
            ypos = new boolean[n];
            yscore = new double[n];
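            // For each prediction: flag whether the example's ground truth matches the target label,
            // and record the probability the model assigned to that label. Predictions without
            // probability scores cannot be scored, hence the exception below.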
            for (int i = 0; i < n; i++) {
                Prediction<Label> prediction = predictions.get(i);
                if (!prediction.hasProbabilities()) {
                    throw new UnsupportedOperationException(String.format("Invalid prediction at index %d: has no probability score.", i));
                }
                if (prediction.getExample().getOutput().equals(label)) {
                    ypos[i] = true;
                }
                yscore[i] = prediction
                        .getOutputScores()
                        .get(label.getLabel())
                        .getScore();
            }
        }
    }

}