001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.evaluation; 018 019import org.tribuo.classification.Label; 020import org.tribuo.evaluation.EvaluationRenderer; 021 022import java.util.ArrayList; 023import java.util.List; 024 025/** 026 * Adds multi-class classification specific metrics to {@link ClassifierEvaluation}. 027 */ 028public interface LabelEvaluation extends ClassifierEvaluation<Label> { 029 030 /** 031 * The overall accuracy of the evaluation. 032 * @return The accuracy. 033 */ 034 double accuracy(); 035 036 /** 037 * The per label accuracy of the evaluation. 038 * @param label The target label. 039 * @return The per label accuracy. 040 */ 041 double accuracy(Label label); 042 043 /** 044 * Area under the ROC curve. 045 * 046 * @param label target label 047 * @return AUC ROC score 048 * 049 * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model 050 * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve. 051 */ 052 double AUCROC(Label label); 053 054 /** 055 * Area under the ROC curve averaged across labels. 056 * <p> 057 * If {@code weighted} is false, use a macro average, if true, weight by the evaluation's observed class counts. 058 * </p> 059 * 060 * @param weighted If true weight by the class counts, if false use a macro average. 061 * @return The average AUCROC. 062 * 063 * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model 064 * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve. 065 */ 066 double averageAUCROC(boolean weighted); 067 068 /** 069 * Summarises a Precision-Recall Curve by taking the weighted mean of the 070 * precisions at a given threshold, where the weight is the recall achieved at 071 * that threshold. 072 * 073 * @see LabelEvaluationUtil#averagedPrecision(boolean[], double[]) 074 * 075 * @param label The target label. 076 * @return The averaged precision for that label. 077 * 078 * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model 079 * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve. 080 */ 081 double averagedPrecision(Label label); 082 083 /** 084 * Calculates the Precision Recall curve for a single label. 085 * 086 * @see LabelEvaluationUtil#generatePRCurve(boolean[], double[]) 087 * 088 * @param label The target label. 089 * @return The precision recall curve for that label. 090 * 091 * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model 092 * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve. 093 */ 094 LabelEvaluationUtil.PRCurve precisionRecallCurve(Label label); 095 096 /** 097 * Returns a HTML formatted String representing this evaluation. 098 * @return A HTML formatted String. 099 */ 100 default String toHTML() { 101 return LabelEvaluation.toHTML(this); 102 } 103 104 /** 105 * This method produces a nicely formatted String output, with 106 * appropriate tabs and newlines, suitable for display on a terminal. 107 * It can be used as an implementation of the {@link EvaluationRenderer} 108 * functional interface. 109 * @param evaluation The evaluation to format. 110 * @return Formatted output showing the main results of the evaluation. 111 */ 112 public static String toFormattedString(LabelEvaluation evaluation) { 113 ConfusionMatrix<Label> cm = evaluation.getConfusionMatrix(); 114 List<Label> labelOrder = new ArrayList<>(cm.getDomain().getDomain()); 115 StringBuilder sb = new StringBuilder(); 116 int tp = 0; 117 int fn = 0; 118 int fp = 0; 119 int n = 0; 120 // 121 // Figure out the biggest class label and therefore the format string 122 // that we should use for them. 123 int maxLabelSize = "Balanced Error Rate".length(); 124 for(Label label : labelOrder) { 125 maxLabelSize = Math.max(maxLabelSize, label.getLabel().length()); 126 } 127 String labelFormatString = String.format("%%-%ds", maxLabelSize+2); 128 sb.append(String.format(labelFormatString, "Class")); 129 sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp")); 130 sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1")); 131 for (Label label : labelOrder) { 132 if (cm.support(label) == 0) { 133 continue; 134 } 135 n += cm.support(label); 136 tp += cm.tp(label); 137 fn += cm.fn(label); 138 fp += cm.fp(label); 139 sb.append(String.format(labelFormatString, label)); 140 sb.append(String.format("%,12d%,12d%,12d%,12d", 141 (int) cm.support(label), 142 (int) cm.tp(label), 143 (int) cm.fn(label), 144 (int) cm.fp(label) 145 )); 146 sb.append(String.format("%12.3f%12.3f%12.3f%n", 147 evaluation.recall(label), 148 evaluation.precision(label), 149 evaluation.f1(label))); 150 } 151 sb.append(String.format(labelFormatString, "Total")); 152 sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp)); 153 sb.append(String.format(labelFormatString, "Accuracy")); 154 sb.append(String.format("%60.3f%n", (double) tp / n)); 155 sb.append(String.format(labelFormatString, "Micro Average")); 156 sb.append(String.format("%60.3f%12.3f%12.3f%n", 157 evaluation.microAveragedRecall(), 158 evaluation.microAveragedPrecision(), 159 evaluation.microAveragedF1())); 160 sb.append(String.format(labelFormatString, "Macro Average")); 161 sb.append(String.format("%60.3f%12.3f%12.3f%n", 162 evaluation.macroAveragedRecall(), 163 evaluation.macroAveragedPrecision(), 164 evaluation.macroAveragedF1())); 165 sb.append(String.format(labelFormatString, "Balanced Error Rate")); 166 sb.append(String.format("%60.3f", evaluation.balancedErrorRate())); 167 return sb.toString(); 168 } 169 170 /** 171 * This method produces a HTML formatted String output, with 172 * appropriate tabs and newlines, suitable for integation into a webpage. 173 * It can be used as an implementation of the {@link EvaluationRenderer} 174 * functional interface. 175 * @param evaluation The evaluation to format. 176 * @return Formatted HTML output showing the main results of the evaluation. 177 */ 178 public static String toHTML(LabelEvaluation evaluation) { 179 ConfusionMatrix<Label> cm = evaluation.getConfusionMatrix(); 180 List<Label> labelOrder = new ArrayList<>(cm.getDomain().getDomain()); 181 StringBuilder sb = new StringBuilder(); 182 int tp = 0; 183 int fn = 0; 184 int fp = 0; 185 int tn = 0; 186 sb.append("<table>\n"); 187 sb.append("<tr>\n"); 188 sb.append("<th>Class</th><th>n</th> <th>%</th> <th>tp</th> <th>fn</th> <th>fp</th> <th>Recall</th> <th>Precision</th> <th>F1</th>"); 189 sb.append("\n</tr>\n"); 190 // 191 // Compute the total number of instances first, so we can show proportions. 192 for (Label label : labelOrder) { 193 //tn += occurrences.getOrDefault(label, 0); 194 tn += cm.tn(label); 195 } 196 for (Label label : labelOrder) { 197 if (cm.support(label) == 0) { 198 continue; 199 } 200 tp += cm.tp(label); 201 fn += cm.fn(label); 202 fp += cm.fp(label); 203 sb.append("<tr>"); 204 sb.append("<td><code>").append(label).append("</code></td>"); 205 int occurrence = (int) cm.support(label); 206 sb.append("<td style=\"text-align:right\">").append(String.format("%,d", occurrence)).append("</td>"); 207 sb.append("<td style=\"text-align:right\">").append(String.format("%8.1f%%", (occurrence/ (double) tn)*100)).append("</td>"); 208 sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.tp(label))).append("</td>"); 209 sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.fn(label))).append("</td>"); 210 sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.fp(label))).append("</td>"); 211 sb.append(String.format("<td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td>%n", 212 evaluation.recall(label), evaluation.precision(label), evaluation.f1(label))); 213 sb.append("</tr>"); 214 } 215 sb.append("<tr>"); 216 sb.append("<td>Total</td>"); 217 sb.append(String.format("<td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\"></td><td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\">%,12d</td>%n", tn, tp, fn, fp)); 218 sb.append("<td colspan=\"4\"></td>"); 219 sb.append("</tr>\n<tr>"); 220 sb.append(String.format("<td>Accuracy</td><td style=\"text-align:right\" colspan=\"6\">%8.3f</td>%n", evaluation.accuracy())); 221 sb.append("<td colspan=\"4\"></td>"); 222 sb.append("</tr>\n<tr>"); 223 sb.append("<td>Micro Average</td>"); 224 sb.append(String.format("<td style=\"text-align:right\" colspan=\"6\">%8.3f</td><td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td>%n", 225 evaluation.microAveragedRecall(), 226 evaluation.microAveragedPrecision(), 227 evaluation.microAveragedF1())); 228 sb.append("</tr></table>"); 229 return sb.toString(); 230 } 231 232}