001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.multilabel.evaluation; 018 019import org.tribuo.Prediction; 020import org.tribuo.classification.evaluation.ConfusionMatrix; 021import org.tribuo.evaluation.metrics.EvaluationMetric.Average; 022import org.tribuo.evaluation.metrics.MetricID; 023import org.tribuo.evaluation.metrics.MetricTarget; 024import org.tribuo.multilabel.MultiLabel; 025import org.tribuo.provenance.EvaluationProvenance; 026 027import java.util.ArrayList; 028import java.util.Collections; 029import java.util.List; 030import java.util.Map; 031 032 033/** 034 * The implementation of a {@link MultiLabelEvaluation} using the default metrics. 035 * <p> 036 * The classification metrics consider labels independently. 037 */ 038public final class MultiLabelEvaluationImpl implements MultiLabelEvaluation { 039 040 private final Map<MetricID<MultiLabel>, Double> results; 041 private final MultiLabelMetric.Context context; 042 private final ConfusionMatrix<MultiLabel> cm; 043 private final EvaluationProvenance provenance; 044 045 /** 046 * Builds an evaluation using the supplied metric results, confusion matrix and evaluation provenance. 047 * @param results The results. 048 * @param context The context carrying the confusion matrix 049 * @param provenance The evaluation provenance. 050 */ 051 MultiLabelEvaluationImpl(Map<MetricID<MultiLabel>, Double> results, 052 MultiLabelMetric.Context context, 053 EvaluationProvenance provenance) { 054 this.results = results; 055 this.context = context; 056 this.cm = context.getCM(); 057 this.provenance = provenance; 058 } 059 060 @Override 061 public List<Prediction<MultiLabel>> getPredictions() { 062 return context.getPredictions(); 063 } 064 065 @Override 066 public double balancedErrorRate() { 067 // Target doesn't matter for balanced error rate, so we just use Average.macro 068 // as it's the macro average of the recalls. 069 MetricTarget<MultiLabel> dummy = MetricTarget.macroAverageTarget(); 070 return get(dummy, MultiLabelMetrics.BALANCED_ERROR_RATE); 071 } 072 073 @Override 074 public ConfusionMatrix<MultiLabel> getConfusionMatrix() { 075 return cm; 076 } 077 078 @Override 079 public double confusion(MultiLabel predicted, MultiLabel truth) { 080 return cm.confusion(predicted, truth); 081 } 082 083 @Override 084 public double tp(MultiLabel label) { 085 return get(label, MultiLabelMetrics.TP); 086 } 087 088 @Override 089 public double tp() { 090 return get(Average.MICRO, MultiLabelMetrics.TP); 091 } 092 093 @Override 094 public double macroTP() { 095 return get(Average.MACRO, MultiLabelMetrics.TP); 096 } 097 098 @Override 099 public double fp(MultiLabel label) { 100 return get(label, MultiLabelMetrics.FP); 101 } 102 103 @Override 104 public double fp() { 105 return get(Average.MICRO, MultiLabelMetrics.FP); 106 } 107 108 @Override 109 public double macroFP() { 110 return get(Average.MACRO, MultiLabelMetrics.FP); 111 } 112 113 @Override 114 public double tn(MultiLabel label) { 115 return get(label, MultiLabelMetrics.TN); 116 } 117 118 @Override 119 public double tn() { 120 return get(Average.MICRO, MultiLabelMetrics.TN); 121 } 122 123 @Override 124 public double macroTN() { 125 return get(Average.MACRO, MultiLabelMetrics.TN); 126 } 127 128 @Override 129 public double fn(MultiLabel label) { 130 return get(label, MultiLabelMetrics.FN); 131 } 132 133 @Override 134 public double fn() { 135 return get(Average.MICRO, MultiLabelMetrics.FN); 136 } 137 138 @Override 139 public double macroFN() { 140 return get(Average.MACRO, MultiLabelMetrics.FN); 141 } 142 143 @Override 144 public double precision(MultiLabel label) { 145 return get(new MetricTarget<>(label), MultiLabelMetrics.PRECISION); 146 } 147 148 @Override 149 public double microAveragedPrecision() { 150 return get(new MetricTarget<>(Average.MICRO), MultiLabelMetrics.PRECISION); 151 } 152 153 @Override 154 public double macroAveragedPrecision() { 155 return get(new MetricTarget<>(Average.MACRO), MultiLabelMetrics.PRECISION); 156 } 157 158 @Override 159 public double recall(MultiLabel label) { 160 return get(new MetricTarget<>(label), MultiLabelMetrics.RECALL); 161 } 162 163 @Override 164 public double microAveragedRecall() { 165 return get(new MetricTarget<>(Average.MICRO), MultiLabelMetrics.RECALL); 166 } 167 168 @Override 169 public double macroAveragedRecall() { 170 return get(new MetricTarget<>(Average.MACRO), MultiLabelMetrics.RECALL); 171 } 172 173 @Override 174 public double f1(MultiLabel label) { 175 return get(new MetricTarget<>(label), MultiLabelMetrics.F1); 176 } 177 178 @Override 179 public double microAveragedF1() { 180 return get(new MetricTarget<>(Average.MICRO), MultiLabelMetrics.F1); 181 } 182 183 @Override 184 public double macroAveragedF1() { 185 return get(new MetricTarget<>(Average.MACRO), MultiLabelMetrics.F1); 186 } 187 188 @Override 189 public Map<MetricID<MultiLabel>, Double> asMap() { 190 return Collections.unmodifiableMap(results); 191 } 192 193 @Override 194 public EvaluationProvenance getProvenance() { 195 return provenance; 196 } 197 198 @Override 199 public String toString() { 200 List<MultiLabel> labelOrder = new ArrayList<>(cm.getDomain().getDomain()); 201 return toString(labelOrder); 202 } 203 204 private String toString(List<MultiLabel> labelOrder) { 205 StringBuilder sb = new StringBuilder(); 206 int tp = 0; 207 int fn = 0; 208 int fp = 0; 209 int n = 0; 210 // 211 // Figure out the biggest class label and therefore the format string 212 // that we should use for them. 213 int maxLabelSize = "Balanced Error Rate".length(); 214 for(MultiLabel label : labelOrder) { 215 maxLabelSize = Math.max(maxLabelSize, label.getLabelString().length()); 216 } 217 String labelFormatString = String.format("%%-%ds", maxLabelSize+2); 218 sb.append(String.format(labelFormatString, "Class")); 219 sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp")); 220 sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1")); 221 for (MultiLabel label : labelOrder) { 222 if (cm.support(label) == 0) { 223 continue; 224 } 225 n += cm.support(label); 226 tp += cm.tp(label); 227 fn += cm.fn(label); 228 fp += cm.fp(label); 229 sb.append(String.format(labelFormatString, label)); 230 sb.append(String.format("%,12d%,12d%,12d%,12d", 231 (int) cm.support(label), 232 (int) cm.tp(label), 233 (int) cm.fn(label), 234 (int) cm.fp(label) 235 )); 236 sb.append(String.format("%12.3f%12.3f%12.3f%n", recall(label), precision(label), f1(label))); 237 } 238 sb.append(String.format(labelFormatString, "Total")); 239 sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp)); 240 sb.append(String.format(labelFormatString, "Accuracy")); 241 sb.append(String.format("%60.3f%n", ((double) tp) / n)); 242 sb.append(String.format(labelFormatString, "Micro Average")); 243 sb.append(String.format("%60.3f%12.3f%12.3f%n", microAveragedRecall(), microAveragedPrecision(), microAveragedF1())); 244 sb.append(String.format(labelFormatString, "Macro Average")); 245 sb.append(String.format("%60.3f%12.3f%12.3f%n", macroAveragedRecall(), macroAveragedPrecision(), macroAveragedF1())); 246 sb.append(String.format(labelFormatString, "Balanced Error Rate")); 247 sb.append(String.format("%60.3f", balancedErrorRate())); 248 return sb.toString(); 249 } 250 251 private double get(MetricTarget<MultiLabel> tgt, MultiLabelMetrics metric) { 252 return get(metric.forTarget(tgt).getID()); 253 } 254 255 private double get(MultiLabel label, MultiLabelMetrics metric) { 256 return get(metric 257 .forTarget(new MetricTarget<>(label)) 258 .getID()); 259 } 260 261 private double get(Average avg, MultiLabelMetrics metric) { 262 return get(metric 263 .forTarget(new MetricTarget<>(avg)) 264 .getID()); 265 } 266 267}