001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.evaluation;
018
019import org.tribuo.classification.Label;
020import org.tribuo.evaluation.EvaluationRenderer;
021
022import java.util.ArrayList;
023import java.util.List;
024
025/**
026 * Adds multi-class classification specific metrics to {@link ClassifierEvaluation}.
027 */
028public interface LabelEvaluation extends ClassifierEvaluation<Label> {
029
030    /**
031     * The overall accuracy of the evaluation.
032     * @return The accuracy.
033     */
034    double accuracy();
035
036    /**
037     * The per label accuracy of the evaluation.
038     * @param label The target label.
039     * @return The per label accuracy.
040     */
041    double accuracy(Label label);
042
043    /**
044     * Area under the ROC curve.
045     *
046     * @param label target label
047     * @return AUC ROC score
048     *
049     * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model
050     * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve.
051     */
052    double AUCROC(Label label);
053
054    /**
055     * Area under the ROC curve averaged across labels.
056     * <p>
057     * If {@code weighted} is false, use a macro average, if true, weight by the evaluation's observed class counts.
058     * </p>
059     *
060     * @param weighted If true weight by the class counts, if false use a macro average.
061     * @return The average AUCROC.
062     *
063     * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model
064     * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve.
065     */
066    double averageAUCROC(boolean weighted);
067
068    /**
069     * Summarises a Precision-Recall Curve by taking the weighted mean of the
070     * precisions at a given threshold, where the weight is the recall achieved at
071     * that threshold.
072     *
073     * @see LabelEvaluationUtil#averagedPrecision(boolean[], double[])
074     *
075     * @param label The target label.
076     * @return The averaged precision for that label.
077     *
078     * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model
079     * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve.
080     */
081    double averagedPrecision(Label label);
082
083    /**
084     * Calculates the Precision Recall curve for a single label.
085     *
086     * @see LabelEvaluationUtil#generatePRCurve(boolean[], double[])
087     *
088     * @param label The target label.
089     * @return The precision recall curve for that label.
090     *
091     * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model
092     * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve.
093     */
094    LabelEvaluationUtil.PRCurve precisionRecallCurve(Label label);
095
096    /**
097     * Returns a HTML formatted String representing this evaluation.
098     * @return A HTML formatted String.
099     */
100    default String toHTML() {
101        return LabelEvaluation.toHTML(this);
102    }
103
104    /**
105     * This method produces a nicely formatted String output, with
106     * appropriate tabs and newlines, suitable for display on a terminal.
107     * It can be used as an implementation of the {@link EvaluationRenderer}
108     * functional interface.
109     * @param evaluation The evaluation to format.
110     * @return Formatted output showing the main results of the evaluation.
111     */
112    public static String toFormattedString(LabelEvaluation evaluation) {
113        ConfusionMatrix<Label> cm = evaluation.getConfusionMatrix();
114        List<Label> labelOrder = new ArrayList<>(cm.getDomain().getDomain());
115        StringBuilder sb = new StringBuilder();
116        int tp = 0;
117        int fn = 0;
118        int fp = 0;
119        int n = 0;
120        //
121        // Figure out the biggest class label and therefore the format string
122        // that we should use for them.
123        int maxLabelSize = "Balanced Error Rate".length();
124        for(Label label : labelOrder) {
125            maxLabelSize = Math.max(maxLabelSize, label.getLabel().length());
126        }
127        String labelFormatString = String.format("%%-%ds", maxLabelSize+2);
128        sb.append(String.format(labelFormatString, "Class"));
129        sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp"));
130        sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1"));
131        for (Label label : labelOrder) {
132            if (cm.support(label) == 0) {
133                continue;
134            }
135            n += cm.support(label);
136            tp += cm.tp(label);
137            fn += cm.fn(label);
138            fp += cm.fp(label);
139            sb.append(String.format(labelFormatString, label));
140            sb.append(String.format("%,12d%,12d%,12d%,12d",
141                    (int) cm.support(label),
142                    (int) cm.tp(label),
143                    (int) cm.fn(label),
144                    (int) cm.fp(label)
145            ));
146            sb.append(String.format("%12.3f%12.3f%12.3f%n",
147                    evaluation.recall(label),
148                    evaluation.precision(label),
149                    evaluation.f1(label)));
150        }
151        sb.append(String.format(labelFormatString, "Total"));
152        sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp));
153        sb.append(String.format(labelFormatString, "Accuracy"));
154        sb.append(String.format("%60.3f%n", (double) tp / n));
155        sb.append(String.format(labelFormatString, "Micro Average"));
156        sb.append(String.format("%60.3f%12.3f%12.3f%n",
157                evaluation.microAveragedRecall(),
158                evaluation.microAveragedPrecision(),
159                evaluation.microAveragedF1()));
160        sb.append(String.format(labelFormatString, "Macro Average"));
161        sb.append(String.format("%60.3f%12.3f%12.3f%n",
162                evaluation.macroAveragedRecall(),
163                evaluation.macroAveragedPrecision(),
164                evaluation.macroAveragedF1()));
165        sb.append(String.format(labelFormatString, "Balanced Error Rate"));
166        sb.append(String.format("%60.3f", evaluation.balancedErrorRate()));
167        return sb.toString();
168    }
169
170    /**
171     * This method produces a HTML formatted String output, with
172     * appropriate tabs and newlines, suitable for integation into a webpage.
173     * It can be used as an implementation of the {@link EvaluationRenderer}
174     * functional interface.
175     * @param evaluation The evaluation to format.
176     * @return Formatted HTML output showing the main results of the evaluation.
177     */
178    public static String toHTML(LabelEvaluation evaluation) {
179        ConfusionMatrix<Label> cm = evaluation.getConfusionMatrix();
180        List<Label> labelOrder = new ArrayList<>(cm.getDomain().getDomain());
181        StringBuilder sb = new StringBuilder();
182        int tp = 0;
183        int fn = 0;
184        int fp = 0;
185        int tn = 0;
186        sb.append("<table>\n");
187        sb.append("<tr>\n");
188        sb.append("<th>Class</th><th>n</th> <th>%</th> <th>tp</th> <th>fn</th> <th>fp</th> <th>Recall</th> <th>Precision</th> <th>F1</th>");
189        sb.append("\n</tr>\n");
190        //
191        // Compute the total number of instances first, so we can show proportions.
192        for (Label label : labelOrder) {
193            //tn += occurrences.getOrDefault(label, 0);
194            tn += cm.tn(label);
195        }
196        for (Label label : labelOrder) {
197            if (cm.support(label) == 0) {
198                continue;
199            }
200            tp += cm.tp(label);
201            fn += cm.fn(label);
202            fp += cm.fp(label);
203            sb.append("<tr>");
204            sb.append("<td><code>").append(label).append("</code></td>");
205            int occurrence = (int) cm.support(label);
206            sb.append("<td style=\"text-align:right\">").append(String.format("%,d", occurrence)).append("</td>");
207            sb.append("<td style=\"text-align:right\">").append(String.format("%8.1f%%", (occurrence/ (double) tn)*100)).append("</td>");
208            sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.tp(label))).append("</td>");
209            sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.fn(label))).append("</td>");
210            sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.fp(label))).append("</td>");
211            sb.append(String.format("<td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td>%n",
212                    evaluation.recall(label), evaluation.precision(label), evaluation.f1(label)));
213            sb.append("</tr>");
214        }
215        sb.append("<tr>");
216        sb.append("<td>Total</td>");
217        sb.append(String.format("<td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\"></td><td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\">%,12d</td>%n", tn, tp, fn, fp));
218        sb.append("<td colspan=\"4\"></td>");
219        sb.append("</tr>\n<tr>");
220        sb.append(String.format("<td>Accuracy</td><td style=\"text-align:right\" colspan=\"6\">%8.3f</td>%n", evaluation.accuracy()));
221        sb.append("<td colspan=\"4\"></td>");
222        sb.append("</tr>\n<tr>");
223        sb.append("<td>Micro Average</td>");
224        sb.append(String.format("<td style=\"text-align:right\" colspan=\"6\">%8.3f</td><td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td>%n",
225                evaluation.microAveragedRecall(),
226                evaluation.microAveragedPrecision(),
227                evaluation.microAveragedF1()));
228        sb.append("</tr></table>");
229        return sb.toString();
230    }
231
232}