001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sequence; 018 019import org.tribuo.Prediction; 020import org.tribuo.classification.Label; 021import org.tribuo.classification.evaluation.ConfusionMatrix; 022import org.tribuo.classification.evaluation.LabelMetric; 023import org.tribuo.classification.evaluation.LabelMetrics; 024import org.tribuo.evaluation.metrics.EvaluationMetric; 025import org.tribuo.evaluation.metrics.MetricID; 026import org.tribuo.evaluation.metrics.MetricTarget; 027import org.tribuo.provenance.EvaluationProvenance; 028import org.tribuo.sequence.SequenceEvaluation; 029 030import java.util.ArrayList; 031import java.util.Collections; 032import java.util.List; 033import java.util.Map; 034import java.util.logging.Logger; 035 036/** 037 * A class that can be used to evaluate a sequence label classification model element wise on a given set of data. 038 */ 039public class LabelSequenceEvaluation implements SequenceEvaluation<Label> { 040 041 private static final Logger logger = Logger.getLogger(LabelSequenceEvaluation.class.getName()); 042 043 private final Map<MetricID<Label>, Double> results; 044 private final LabelMetric.Context ctx; 045 private final ConfusionMatrix<Label> cm; 046 private final EvaluationProvenance provenance; 047 048 protected LabelSequenceEvaluation(Map<MetricID<Label>, Double> results, 049 LabelMetric.Context ctx, 050 EvaluationProvenance provenance) { 051 this.results = results; 052 this.ctx = ctx; 053 this.cm = ctx.getCM(); 054 this.provenance = provenance; 055 } 056 057 /** 058 * Gets the flattened predictions. 059 * @return The flattened predictions. 060 */ 061 public List<Prediction<Label>> getPredictions() { 062 return ctx.getPredictions(); 063 } 064 065 /** 066 * Gets the confusion matrix backing this evaluation. 067 * @return The confusion matrix. 068 */ 069 public ConfusionMatrix<Label> getConfusionMatrix() { 070 return cm; 071 } 072 073 @Override 074 public Map<MetricID<Label>, Double> asMap() { 075 return Collections.unmodifiableMap(results); 076 } 077 078 /** 079 * Note: confusion is not stored in the underlying map, so it won't show up in aggregation. 080 * @param predictedLabel The predicted label. 081 * @param trueLabel The true label. 082 * @return The number of times that {@code predictedLabel} was predicted for <code>trueLabel</code>. 083 */ 084 public double confusion(Label predictedLabel, Label trueLabel) { 085 return cm.confusion(predictedLabel, trueLabel); 086 } 087 088 public double tp(Label label) { 089 return get(label, LabelMetrics.TP); 090 } 091 092 public double tp() { 093 return get(EvaluationMetric.Average.MICRO, LabelMetrics.TP); 094 } 095 096 public double macroTP() { 097 return get(EvaluationMetric.Average.MACRO, LabelMetrics.TP); 098 } 099 100 public double fp(Label label) { 101 return get(label, LabelMetrics.FP); 102 } 103 104 public double fp() { 105 return get(EvaluationMetric.Average.MICRO, LabelMetrics.FP); 106 } 107 108 public double macroFP() { 109 return get(EvaluationMetric.Average.MACRO, LabelMetrics.FP); 110 } 111 112 public double tn(Label label) { 113 return get(label, LabelMetrics.TN); 114 } 115 116 public double tn() { 117 return get(EvaluationMetric.Average.MICRO, LabelMetrics.TN); 118 } 119 120 public double macroTN() { 121 return get(EvaluationMetric.Average.MACRO, LabelMetrics.TN); 122 } 123 124 public double fn(Label label) { 125 return get(label, LabelMetrics.FN); 126 } 127 128 public double fn() { 129 return get(EvaluationMetric.Average.MICRO, LabelMetrics.FN); 130 } 131 132 public double macroFN() { 133 return get(EvaluationMetric.Average.MACRO, LabelMetrics.FN); 134 } 135 136 public double precision(Label label) { 137 return get(label, LabelMetrics.PRECISION); 138 } 139 140 public double microAveragedPrecision() { 141 return get(EvaluationMetric.Average.MICRO, LabelMetrics.PRECISION); 142 } 143 144 public double macroAveragedPrecision() { 145 return get(EvaluationMetric.Average.MACRO, LabelMetrics.PRECISION); 146 } 147 148 public double recall(Label label) { 149 return get(label, LabelMetrics.RECALL); 150 } 151 152 public double microAveragedRecall() { 153 return get(EvaluationMetric.Average.MICRO, LabelMetrics.RECALL); 154 } 155 156 public double macroAveragedRecall() { 157 return get(EvaluationMetric.Average.MACRO, LabelMetrics.RECALL); 158 } 159 160 public double f1(Label label) { 161 return get(label, LabelMetrics.RECALL); 162 } 163 164 public double microAveragedF1() { 165 return get(EvaluationMetric.Average.MICRO, LabelMetrics.F1); 166 } 167 168 public double macroAveragedF1() { 169 return get(EvaluationMetric.Average.MACRO, LabelMetrics.F1); 170 } 171 172 public double accuracy() { 173 return get(EvaluationMetric.Average.MICRO, LabelMetrics.ACCURACY); 174 } 175 176 public double accuracy(Label label) { 177 return get(label, LabelMetrics.ACCURACY); 178 } 179 180 public double balancedErrorRate() { 181 // Target doesn't matter for balanced error rate, so we just use Average.macro 182 // as it's the macro averaged recall. 183 return get(EvaluationMetric.Average.MACRO, LabelMetrics.BALANCED_ERROR_RATE); 184 } 185 186 @Override 187 public EvaluationProvenance getProvenance() { return provenance; } 188 189 @Override 190 public String toString() { 191 List<Label> labelOrder = new ArrayList<>(cm.getDomain().getDomain()); 192 StringBuilder sb = new StringBuilder(); 193 int tp = 0; 194 int fn = 0; 195 int fp = 0; 196 int n = 0; 197 // 198 // Figure out the biggest class label and therefore the format string 199 // that we should use for them. 200 int maxLabelSize = "Balanced Error Rate".length(); 201 for(Label label : labelOrder) { 202 maxLabelSize = Math.max(maxLabelSize, label.getLabel().length()); 203 } 204 String labelFormatString = String.format("%%-%ds", maxLabelSize+2); 205 sb.append(String.format(labelFormatString, "Class")); 206 sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp")); 207 sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1")); 208 for (Label label : labelOrder) { 209 if (cm.support(label) == 0) { 210 continue; 211 } 212 n += cm.support(label); 213 tp += cm.tp(label); 214 fn += cm.fn(label); 215 fp += cm.fp(label); 216 sb.append(String.format(labelFormatString, label)); 217 sb.append(String.format("%,12d%,12d%,12d%,12d", 218 (int) cm.support(label), 219 (int) cm.tp(label), 220 (int) cm.fn(label), 221 (int) cm.fp(label) 222 )); 223 sb.append(String.format("%12.3f%12.3f%12.3f%n", recall(label), precision(label), f1(label))); 224 } 225 sb.append(String.format(labelFormatString, "Total")); 226 sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp)); 227 sb.append(String.format(labelFormatString, "Accuracy")); 228 sb.append(String.format("%60.3f%n", (double) tp / n)); 229 sb.append(String.format(labelFormatString, "Micro Average")); 230 sb.append(String.format("%60.3f%12.3f%12.3f%n", microAveragedRecall(), microAveragedPrecision(), microAveragedF1())); 231 sb.append(String.format(labelFormatString, "Macro Average")); 232 sb.append(String.format("%60.3f%12.3f%12.3f%n", macroAveragedRecall(), macroAveragedPrecision(), macroAveragedF1())); 233 sb.append(String.format(labelFormatString, "Balanced Error Rate")); 234 sb.append(String.format("%60.3f", balancedErrorRate())); 235 return sb.toString(); 236 } 237 238 private double get(MetricTarget<Label> tgt, LabelMetrics metric) { 239 return get(metric.forTarget(tgt).getID()); 240 } 241 242 private double get(Label label, LabelMetrics metric) { 243 return get(metric 244 .forTarget(new MetricTarget<>(label)) 245 .getID()); 246 } 247 248 private double get(EvaluationMetric.Average avg, LabelMetrics metric) { 249 return get(metric 250 .forTarget(new MetricTarget<>(avg)) 251 .getID()); 252 } 253}