/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.evaluation;

import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelMetric.Context;
import org.tribuo.evaluation.metrics.MetricTarget;

import java.util.List;
import java.util.function.ToDoubleBiFunction;

/**
 * An enum of the default {@link LabelMetric}s supported by the multi-class
 * classification evaluation package.
 *
 * <p>Each constant wraps a function from a {@link MetricTarget} and an evaluation
 * {@link Context} to a double-valued score; {@link #forTarget(MetricTarget)} binds
 * that function to a concrete target, producing a {@link LabelMetric}.
 */
public enum LabelMetrics {

    /**
     * The number of true positives.
     */
    TP((target, context) -> ConfusionMetrics.tp(target, context.getCM())),
    /**
     * The number of false positives.
     */
    FP((target, context) -> ConfusionMetrics.fp(target, context.getCM())),
    /**
     * The number of true negatives.
     */
    TN((target, context) -> ConfusionMetrics.tn(target, context.getCM())),
    /**
     * The number of false negatives.
     */
    FN((target, context) -> ConfusionMetrics.fn(target, context.getCM())),
    /**
     * The precision, i.e., the number of true positives divided by the number of predicted positives.
     */
    PRECISION((target, context) -> ConfusionMetrics.precision(target, context.getCM())),
    /**
     * The recall, i.e., the number of true positives divided by the number of ground truth positives.
     */
    RECALL((target, context) -> ConfusionMetrics.recall(target, context.getCM())),
    /**
     * The F_1 score, i.e., the harmonic mean of the precision and the recall.
     */
    F1((target, context) -> ConfusionMetrics.f1(target, context.getCM())),
    /**
     * The accuracy.
     */
    ACCURACY((target, context) -> ConfusionMetrics.accuracy(target, context.getCM())),
    /**
     * The balanced error rate, i.e., the mean of the per class recalls.
     */
    // The target argument is intentionally unused; the BER is computed over the whole confusion matrix.
    BALANCED_ERROR_RATE((target, context) -> ConfusionMetrics.balancedErrorRate(context.getCM())),
    /**
     * The area under the receiver-operator curve (ROC).
     */
    AUCROC((target, context) -> LabelMetrics.AUCROC(target, context.getPredictions())),
    /**
     * The averaged precision.
     */
    AVERAGED_PRECISION((target, context) -> LabelMetrics.averagedPrecision(target, context.getPredictions()));

    // The scoring function this metric delegates to.
    private final ToDoubleBiFunction<MetricTarget<Label>, LabelMetric.Context> impl;

    LabelMetrics(ToDoubleBiFunction<MetricTarget<Label>, LabelMetric.Context> impl) {
        this.impl = impl;
    }

    /**
     * Returns the implementing function for this metric.
     * @return The implementing function.
     */
    public ToDoubleBiFunction<MetricTarget<Label>, Context> getImpl() {
        return impl;
    }

    /**
     * Gets the LabelMetric wrapped around the supplied MetricTarget.
     * @param tgt The metric target.
     * @return The label metric combining the implementation function with the supplied metric target.
     */
    public LabelMetric forTarget(MetricTarget<Label> tgt) {
        return new LabelMetric(tgt, name(), impl);
    }

    /**
     * Computes the averaged precision for the label named by the supplied target.
     * @see LabelEvaluationUtil#averagedPrecision(boolean[], double[])
     * @param tgt The metric target to use.
     * @param predictions The predictions to use.
     * @return The averaged precision for the supplied target with the supplied predictions.
     * @throws IllegalStateException If the target does not name a specific output label.
     */
    public static double averagedPrecision(MetricTarget<Label> tgt, List<Prediction<Label>> predictions) {
        if (!tgt.getOutputTarget().isPresent()) {
            throw new IllegalStateException("Unsupported MetricTarget for averagedPrecision");
        }
        return averagedPrecision(tgt.getOutputTarget().get(), predictions);
    }

    /**
     * Computes the averaged precision for the supplied label.
     * @see LabelEvaluationUtil#averagedPrecision(boolean[], double[])
     * @param label The Label to average across.
     * @param predictions The predictions to use.
     * @return The averaged precision for the supplied label with the supplied predictions.
     */
    public static double averagedPrecision(Label label, List<Prediction<Label>> predictions) {
        PredictionProbabilities probs = new PredictionProbabilities(label, predictions);
        return LabelEvaluationUtil.averagedPrecision(probs.ypos, probs.yscore);
    }

    /**
     * Computes the precision-recall curve for the supplied label.
     * @see LabelEvaluationUtil#generatePRCurve(boolean[], double[])
     * @param label The Label to calculate precision and recall for.
     * @param predictions The predictions to use.
     * @return The Precision Recall Curve for the supplied label with the supplied predictions.
     */
    public static LabelEvaluationUtil.PRCurve precisionRecallCurve(Label label, List<Prediction<Label>> predictions) {
        PredictionProbabilities probs = new PredictionProbabilities(label, predictions);
        return LabelEvaluationUtil.generatePRCurve(probs.ypos, probs.yscore);
    }

    /**
     * Area under the ROC curve.
     *
     * @param label the label corresponding to the "positive" class
     * @param predictions the predictions for which we'll compute the score
     * @return AUC ROC score
     * @throws UnsupportedOperationException if any prediction lacks a probability score,
     * as probabilities are required to compute the ROC curve.
     * (See also: {@link Model#generatesProbabilities()})
     */
    public static double AUCROC(Label label, List<Prediction<Label>> predictions) {
        PredictionProbabilities probs = new PredictionProbabilities(label, predictions);
        return LabelEvaluationUtil.binaryAUCROC(probs.ypos, probs.yscore);
    }

    /**
     * Area under the ROC curve.
     *
     * @param tgt The metric target for the positive class.
     * @param predictions the predictions for which we'll compute the score
     * @return AUC ROC score
     * @throws UnsupportedOperationException if any prediction lacks a probability score,
     * as probabilities are required to compute the ROC curve.
     * (See also: {@link Model#generatesProbabilities()})
     * @throws IllegalStateException If the target does not name a specific output label.
     */
    public static double AUCROC(MetricTarget<Label> tgt, List<Prediction<Label>> predictions) {
        if (!tgt.getOutputTarget().isPresent()) {
            throw new IllegalStateException("Unsupported MetricTarget for AUCROC");
        }
        return AUCROC(tgt.getOutputTarget().get(), predictions);
    }

    /**
     * Pairs the positive-class indicator and the predicted score for each prediction,
     * relative to a single label. One day, it'll be a record. Not today mind.
     */
    private static final class PredictionProbabilities {
        // ypos[i] is true iff prediction i's ground truth equals the label of interest.
        final boolean[] ypos;
        // yscore[i] is the score prediction i assigned to the label of interest.
        final double[] yscore;
        PredictionProbabilities(Label label, List<Prediction<Label>> predictions) {
            int size = predictions.size();
            ypos = new boolean[size];
            yscore = new double[size];
            for (int idx = 0; idx < size; idx++) {
                Prediction<Label> current = predictions.get(idx);
                if (!current.hasProbabilities()) {
                    throw new UnsupportedOperationException(String.format("Invalid prediction at index %d: has no probability score.", idx));
                }
                ypos[idx] = current.getExample().getOutput().equals(label);
                yscore[idx] = current.getOutputScores().get(label.getLabel()).getScore();
            }
        }
    }

}