001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.evaluation;
018
019import com.oracle.labs.mlrg.olcut.util.SortUtil;
020import org.tribuo.classification.Label;
021import org.tribuo.util.Util;
022
023import java.util.ArrayList;
024import java.util.Collections;
025import java.util.List;
026
027/**
028 * Static utility functions for calculating performance metrics on {@link Label}s.
029 */
030public final class LabelEvaluationUtil {
031
032    // Static utility class, has private constructor and is final.
033    private LabelEvaluationUtil() {}
034
035    /**
036     * Summarises a Precision-Recall Curve by taking the weighted mean of the
037     * precisions at a given threshold, where the weight is the recall achieved at
038     * that threshold.
039     *
040     * Follows scikit-learn's implementation.
041     *
042     * In general use the AUC for a Precision-Recall Gain curve as the area under
043     * the precision-recall curve is not properly normalized.
044     * @param yPos Each element is true if the label was from the positive class.
045     * @param yScore Each element is the score of the positive class.
046     * @return The averaged precision.
047     */
048    public static double averagedPrecision(boolean[] yPos, double[] yScore) {
049        PRCurve prc = generatePRCurve(yPos,yScore);
050
051        double score = 0.0;
052
053        for (int i = 0; i < prc.precision.length-1; i++) {
054            score += (prc.recall[i+1] - prc.recall[i]) * prc.precision[i];
055        }
056
057        return -score;
058    }
059
060    /**
061     * Calculates the Precision Recall curve for a single label.
062     *
063     * In general use Precision-Recall Gain curves.
064     * @param yPos Each element is true if the label was from the positive class.
065     * @param yScore Each element is the score of the positive class.
066     * @return The PRCurve for one label.
067     */
068    public static PRCurve generatePRCurve(boolean[] yPos, double[] yScore) {
069        TPFP tpfp = generateTPFPs(yPos,yScore);
070
071        ArrayList<Double> precisions = new ArrayList<>(tpfp.falsePos.size());
072        ArrayList<Double> recalls = new ArrayList<>(tpfp.falsePos.size());
073        ArrayList<Double> thresholds = new ArrayList<>(tpfp.falsePos.size());
074
075        for (int i = 0; i < tpfp.falsePos.size(); i++) {
076            double curFalsePos = tpfp.falsePos.get(i);
077            double curTruePos = tpfp.truePos.get(i);
078
079            double precision = 0.0;
080            double recall = 0.0;
081            if (curTruePos != 0) {
082                precision = curTruePos / (curTruePos + curFalsePos);
083                recall = curTruePos / tpfp.totalPos;
084            }
085
086            precisions.add(precision);
087            recalls.add(recall);
088            thresholds.add(tpfp.thresholds.get(i));
089
090            // Break out if we've achieved full recall.
091            if (curTruePos == tpfp.totalPos) {
092                break;
093            }
094        }
095
096        Collections.reverse(precisions);
097        Collections.reverse(recalls);
098        Collections.reverse(thresholds);
099
100        precisions.add(1.0);
101        recalls.add(0.0);
102
103        return new PRCurve(Util.toPrimitiveDouble(precisions),Util.toPrimitiveDouble(recalls),Util.toPrimitiveDouble(thresholds));
104    }
105
106    /**
107     * Calculates the area under the receiver operator characteristic curve,
108     * i.e., the AUC of the ROC curve.
109     * @param yPos Is the associated index a positive label.
110     * @param yScore The score of the positive class.
111     * @return The auc (a value bounded 0.0-1.0).
112     */
113    public static double binaryAUCROC(boolean[] yPos, double[] yScore) {
114        ROC roc = generateROCCurve(yPos, yScore);
115        return Util.auc(roc.fpr, roc.tpr);
116    }
117
118    /**
119     * Calculates the binary ROC for a single label.
120     * @param yPos Each element is true if the label was from the positive class.
121     * @param yScore Each element is the score of the positive class.
122     * @return The ROC for one label.
123     */
124    public static ROC generateROCCurve(boolean[] yPos, double[] yScore) {
125        TPFP tpfp = generateTPFPs(yPos,yScore);
126
127        // If it doesn't exist, add a 0,0 point so the graph always starts from the origin.
128        // This point has a threshold of POSITIVE_INFINITY as it's the always negative classifier.
129        if ((tpfp.truePos.get(0) != 0) || (tpfp.falsePos.get(0) != 0)) {
130            tpfp.truePos.add(0,0);
131            tpfp.falsePos.add(0,0);
132            tpfp.thresholds.add(0,Double.POSITIVE_INFINITY); // Set threshold to positive infinity
133        }
134
135        // Transform things back into arrays.
136        double[] truePosArr = Util.toPrimitiveDoubleFromInteger(tpfp.truePos);
137        double[] falsePosArr = Util.toPrimitiveDoubleFromInteger(tpfp.falsePos);
138        double[] thresholdsArr = Util.toPrimitiveDouble(tpfp.thresholds);
139
140        // Convert from counts into a rate.
141        double maxTrue = truePosArr[truePosArr.length-1];
142        double maxFalse = falsePosArr[falsePosArr.length-1];
143        for (int i = 0; i < truePosArr.length; i++) {
144            truePosArr[i] /= maxTrue;
145            falsePosArr[i] /= maxFalse;
146        }
147
148        return new ROC(falsePosArr,truePosArr,thresholdsArr);
149    }
150
151    private static TPFP generateTPFPs(boolean[] yPos, double[] yScore) {
152        if (yPos.length != yScore.length) {
153            throw new IllegalArgumentException("yPos and yScore must be the same length, yPos.length = " + yPos.length + ", yScore.length = " + yScore.length);
154        }
155        // First sort the predictions by their score
156        // and apply that sort to the true labels and the predictions.
157        int[] sortedIndices = SortUtil.argsort(yScore,false);
158        double[] sortedScore = new double[yScore.length];
159        boolean[] sortedPos = new boolean[yPos.length];
160        int totalPos = 0;
161        for (int i = 0; i < yScore.length; i++) {
162            sortedScore[i] = yScore[sortedIndices[i]];
163            sortedPos[i] = yPos[sortedIndices[i]];
164            if (sortedPos[i]) {
165                totalPos++;
166            }
167        }
168
169        // Find all the differences in the score values as values with
170        // the same score need to be compressed into a single ROC point.
171        int[] differentIndices = Util.differencesIndices(sortedScore);
172        int[] truePosSum = Util.cumulativeSum(sortedPos);
173
174        // Calculate the number of true positives and false positives for each score threshold.
175        ArrayList<Integer> truePos = new ArrayList<>();
176        ArrayList<Integer> falsePos = new ArrayList<>();
177        ArrayList<Double> thresholds = new ArrayList<>();
178        for (int i = 0; i < differentIndices.length; i++) {
179            thresholds.add(sortedScore[differentIndices[i]]);
180            truePos.add(truePosSum[differentIndices[i]]);
181            falsePos.add(1+(differentIndices[i] - truePosSum[differentIndices[i]]));
182        }
183
184        return new TPFP(falsePos,truePos,thresholds,totalPos);
185    }
186
187    private static class TPFP {
188        public final List<Integer> falsePos;
189        public final List<Integer> truePos;
190        public final List<Double> thresholds;
191        public final int totalPos;
192
193        public TPFP(List<Integer> falsePos, List<Integer> truePos, List<Double> thresholds, int totalPos) {
194            this.falsePos = falsePos;
195            this.truePos = truePos;
196            this.thresholds = thresholds;
197            this.totalPos = totalPos;
198        }
199    }
200
201    /**
202     * Stores the ROC curve as three arrays: the false positive rate, the true positive rate,
203     * and the thresholds associated with those rates.
204     *
205     * By definition if both tpr and fpr are zero for the first value, the threshold is positive infinity.
206     */
207    public static class ROC {
208        public final double[] fpr;
209        public final double[] tpr;
210        public final double[] thresholds;
211
212        public ROC(double[] fpr, double[] tpr, double[] thresholds) {
213            this.fpr = fpr;
214            this.tpr = tpr;
215            this.thresholds = thresholds;
216        }
217    }
218
219    /**
220     * Stores the Precision-Recall curve as three arrays: the precisions, the recalls,
221     * and the thresholds associated with those values.
222     */
223    public static class PRCurve {
224        public final double[] precision;
225        public final double[] recall;
226        public final double[] thresholds;
227
228        public PRCurve(double[] precision, double[] recall, double[] thresholds) {
229            this.precision = precision;
230            this.recall = recall;
231            this.thresholds = thresholds;
232        }
233    }
234}