001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.evaluation; 018 019import com.oracle.labs.mlrg.olcut.util.SortUtil; 020import org.tribuo.classification.Label; 021import org.tribuo.util.Util; 022 023import java.util.ArrayList; 024import java.util.Collections; 025import java.util.List; 026 027/** 028 * Static utility functions for calculating performance metrics on {@link Label}s. 029 */ 030public final class LabelEvaluationUtil { 031 032 // Static utility class, has private constructor and is final. 033 private LabelEvaluationUtil() {} 034 035 /** 036 * Summarises a Precision-Recall Curve by taking the weighted mean of the 037 * precisions at a given threshold, where the weight is the recall achieved at 038 * that threshold. 039 * 040 * Follows scikit-learn's implementation. 041 * 042 * In general use the AUC for a Precision-Recall Gain curve as the area under 043 * the precision-recall curve is not properly normalized. 044 * @param yPos Each element is true if the label was from the positive class. 045 * @param yScore Each element is the score of the positive class. 046 * @return The averaged precision. 047 */ 048 public static double averagedPrecision(boolean[] yPos, double[] yScore) { 049 PRCurve prc = generatePRCurve(yPos,yScore); 050 051 double score = 0.0; 052 053 for (int i = 0; i < prc.precision.length-1; i++) { 054 score += (prc.recall[i+1] - prc.recall[i]) * prc.precision[i]; 055 } 056 057 return -score; 058 } 059 060 /** 061 * Calculates the Precision Recall curve for a single label. 062 * 063 * In general use Precision-Recall Gain curves. 064 * @param yPos Each element is true if the label was from the positive class. 065 * @param yScore Each element is the score of the positive class. 066 * @return The PRCurve for one label. 067 */ 068 public static PRCurve generatePRCurve(boolean[] yPos, double[] yScore) { 069 TPFP tpfp = generateTPFPs(yPos,yScore); 070 071 ArrayList<Double> precisions = new ArrayList<>(tpfp.falsePos.size()); 072 ArrayList<Double> recalls = new ArrayList<>(tpfp.falsePos.size()); 073 ArrayList<Double> thresholds = new ArrayList<>(tpfp.falsePos.size()); 074 075 for (int i = 0; i < tpfp.falsePos.size(); i++) { 076 double curFalsePos = tpfp.falsePos.get(i); 077 double curTruePos = tpfp.truePos.get(i); 078 079 double precision = 0.0; 080 double recall = 0.0; 081 if (curTruePos != 0) { 082 precision = curTruePos / (curTruePos + curFalsePos); 083 recall = curTruePos / tpfp.totalPos; 084 } 085 086 precisions.add(precision); 087 recalls.add(recall); 088 thresholds.add(tpfp.thresholds.get(i)); 089 090 // Break out if we've achieved full recall. 091 if (curTruePos == tpfp.totalPos) { 092 break; 093 } 094 } 095 096 Collections.reverse(precisions); 097 Collections.reverse(recalls); 098 Collections.reverse(thresholds); 099 100 precisions.add(1.0); 101 recalls.add(0.0); 102 103 return new PRCurve(Util.toPrimitiveDouble(precisions),Util.toPrimitiveDouble(recalls),Util.toPrimitiveDouble(thresholds)); 104 } 105 106 /** 107 * Calculates the area under the receiver operator characteristic curve, 108 * i.e., the AUC of the ROC curve. 109 * @param yPos Is the associated index a positive label. 110 * @param yScore The score of the positive class. 111 * @return The auc (a value bounded 0.0-1.0). 112 */ 113 public static double binaryAUCROC(boolean[] yPos, double[] yScore) { 114 ROC roc = generateROCCurve(yPos, yScore); 115 return Util.auc(roc.fpr, roc.tpr); 116 } 117 118 /** 119 * Calculates the binary ROC for a single label. 120 * @param yPos Each element is true if the label was from the positive class. 121 * @param yScore Each element is the score of the positive class. 122 * @return The ROC for one label. 123 */ 124 public static ROC generateROCCurve(boolean[] yPos, double[] yScore) { 125 TPFP tpfp = generateTPFPs(yPos,yScore); 126 127 // If it doesn't exist, add a 0,0 point so the graph always starts from the origin. 128 // This point has a threshold of POSITIVE_INFINITY as it's the always negative classifier. 129 if ((tpfp.truePos.get(0) != 0) || (tpfp.falsePos.get(0) != 0)) { 130 tpfp.truePos.add(0,0); 131 tpfp.falsePos.add(0,0); 132 tpfp.thresholds.add(0,Double.POSITIVE_INFINITY); // Set threshold to positive infinity 133 } 134 135 // Transform things back into arrays. 136 double[] truePosArr = Util.toPrimitiveDoubleFromInteger(tpfp.truePos); 137 double[] falsePosArr = Util.toPrimitiveDoubleFromInteger(tpfp.falsePos); 138 double[] thresholdsArr = Util.toPrimitiveDouble(tpfp.thresholds); 139 140 // Convert from counts into a rate. 141 double maxTrue = truePosArr[truePosArr.length-1]; 142 double maxFalse = falsePosArr[falsePosArr.length-1]; 143 for (int i = 0; i < truePosArr.length; i++) { 144 truePosArr[i] /= maxTrue; 145 falsePosArr[i] /= maxFalse; 146 } 147 148 return new ROC(falsePosArr,truePosArr,thresholdsArr); 149 } 150 151 private static TPFP generateTPFPs(boolean[] yPos, double[] yScore) { 152 if (yPos.length != yScore.length) { 153 throw new IllegalArgumentException("yPos and yScore must be the same length, yPos.length = " + yPos.length + ", yScore.length = " + yScore.length); 154 } 155 // First sort the predictions by their score 156 // and apply that sort to the true labels and the predictions. 157 int[] sortedIndices = SortUtil.argsort(yScore,false); 158 double[] sortedScore = new double[yScore.length]; 159 boolean[] sortedPos = new boolean[yPos.length]; 160 int totalPos = 0; 161 for (int i = 0; i < yScore.length; i++) { 162 sortedScore[i] = yScore[sortedIndices[i]]; 163 sortedPos[i] = yPos[sortedIndices[i]]; 164 if (sortedPos[i]) { 165 totalPos++; 166 } 167 } 168 169 // Find all the differences in the score values as values with 170 // the same score need to be compressed into a single ROC point. 171 int[] differentIndices = Util.differencesIndices(sortedScore); 172 int[] truePosSum = Util.cumulativeSum(sortedPos); 173 174 // Calculate the number of true positives and false positives for each score threshold. 175 ArrayList<Integer> truePos = new ArrayList<>(); 176 ArrayList<Integer> falsePos = new ArrayList<>(); 177 ArrayList<Double> thresholds = new ArrayList<>(); 178 for (int i = 0; i < differentIndices.length; i++) { 179 thresholds.add(sortedScore[differentIndices[i]]); 180 truePos.add(truePosSum[differentIndices[i]]); 181 falsePos.add(1+(differentIndices[i] - truePosSum[differentIndices[i]])); 182 } 183 184 return new TPFP(falsePos,truePos,thresholds,totalPos); 185 } 186 187 private static class TPFP { 188 public final List<Integer> falsePos; 189 public final List<Integer> truePos; 190 public final List<Double> thresholds; 191 public final int totalPos; 192 193 public TPFP(List<Integer> falsePos, List<Integer> truePos, List<Double> thresholds, int totalPos) { 194 this.falsePos = falsePos; 195 this.truePos = truePos; 196 this.thresholds = thresholds; 197 this.totalPos = totalPos; 198 } 199 } 200 201 /** 202 * Stores the ROC curve as three arrays: the false positive rate, the true positive rate, 203 * and the thresholds associated with those rates. 204 * 205 * By definition if both tpr and fpr are zero for the first value, the threshold is positive infinity. 206 */ 207 public static class ROC { 208 public final double[] fpr; 209 public final double[] tpr; 210 public final double[] thresholds; 211 212 public ROC(double[] fpr, double[] tpr, double[] thresholds) { 213 this.fpr = fpr; 214 this.tpr = tpr; 215 this.thresholds = thresholds; 216 } 217 } 218 219 /** 220 * Stores the Precision-Recall curve as three arrays: the precisions, the recalls, 221 * and the thresholds associated with those values. 222 */ 223 public static class PRCurve { 224 public final double[] precision; 225 public final double[] recall; 226 public final double[] thresholds; 227 228 public PRCurve(double[] precision, double[] recall, double[] thresholds) { 229 this.precision = precision; 230 this.recall = recall; 231 this.thresholds = thresholds; 232 } 233 } 234}