001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.evaluation; 018 019import org.tribuo.evaluation.metrics.EvaluationMetric; 020import org.tribuo.evaluation.metrics.MetricTarget; 021import org.tribuo.regression.Regressor; 022import org.tribuo.util.Util; 023 024import java.util.function.BiFunction; 025import java.util.function.ToDoubleBiFunction; 026 027/** 028 * An enum of the default {@link RegressionMetric}s supported by the multi-dimensional regression 029 * evaluation package. 030 * <p> 031 * The metrics treat each regressed dimension independently. 032 */ 033public enum RegressionMetrics { 034 035 /** 036 * Calculates the R^2 of the predictions. 037 */ 038 R2((target, context) -> RegressionMetrics.r2(target, context.getMemo())), 039 /** 040 * Calculates the Root Mean Squared Error of the predictions. 041 */ 042 RMSE((target, context) -> RegressionMetrics.rmse(target, context.getMemo())), 043 /** 044 * Calculates the Mean Absolute Error of the predictions. 045 */ 046 MAE((target, context) -> RegressionMetrics.mae(target, context.getMemo())), 047 /** 048 * Calculates the Explained Variance of the predictions. 049 */ 050 EV((target, context) -> RegressionMetrics.explainedVariance(target, context.getMemo())); 051 052 private final ToDoubleBiFunction<MetricTarget<Regressor>, RegressionMetric.Context> impl; 053 RegressionMetrics(ToDoubleBiFunction<MetricTarget<Regressor>, RegressionMetric.Context> impl) { 054 this.impl = impl; 055 } 056 057 RegressionMetric forTarget(MetricTarget<Regressor> target) { 058 return new RegressionMetric(target, this.name(), this.impl); 059 } 060 061 /** 062 * Calculates R^2 based on the supplied statistics. 063 * @param target The regression dimension or average to target. 064 * @param sufficientStats The sufficient statistics. 065 * @return The R^2 value of the predictions. 066 */ 067 public static double r2(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) { 068 return compute(target, sufficientStats, RegressionMetrics::r2); 069 } 070 071 /** 072 * Calculates R^2 based on the supplied statistics for a single dimension. 073 * @param variable The regression dimension. 074 * @param sufficientStats The sufficient statistics. 075 * @return The R^2 value of the predictions. 076 */ 077 public static double r2(Regressor variable, RegressionSufficientStatistics sufficientStats) { 078 String varname = variable.getNames()[0]; 079 double[] trueArray = sufficientStats.trueValues.get(varname); 080 double numerator = sufficientStats.sumSquaredError.get(varname).doubleValue(); 081 double meanTruth = Util.weightedMean(trueArray, sufficientStats.weights, sufficientStats.n); 082 double denominator = 0.0; 083 for (int i = 0; i < sufficientStats.n; i++) { 084 double difference = trueArray[i] - meanTruth; 085 float currWeight = sufficientStats.weights[i]; 086 denominator += currWeight * difference * difference; 087 } 088 return 1.0 - (numerator / denominator); 089 } 090 091 /** 092 * Calculates the RMSE based on the supplied statistics. 093 * @param target The regression dimension or average to target. 094 * @param sufficientStats The sufficient statistics. 095 * @return The RMSE of the predictions. 096 */ 097 public static double rmse(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) { 098 return compute(target, sufficientStats, RegressionMetrics::rmse); 099 } 100 101 /** 102 * Calculates the RMSE based on the supplied statistics for a single dimension. 103 * @param variable The regression dimension to target. 104 * @param sufficientStats The sufficient statistics. 105 * @return The RMSE of the predictions. 106 */ 107 public static double rmse(Regressor variable, RegressionSufficientStatistics sufficientStats) { 108 String varname = variable.getNames()[0]; 109 double sumSqErr = sufficientStats.sumSquaredError.get(varname).doubleValue(); 110 return Math.sqrt(sumSqErr / sufficientStats.weightSum); 111 } 112 113 /** 114 * Calculates the Mean Absolute Error based on the supplied statistics. 115 * @param target The regression dimension or average to target. 116 * @param sufficientStats The sufficient statistics. 117 * @return The MAE of the predictions. 118 */ 119 public static double mae(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) { 120 return compute(target, sufficientStats, RegressionMetrics::mae); 121 } 122 123 /** 124 * Calculates the Mean Absolute Error based on the supplied statistics for a single dimension. 125 * @param variable The regression dimension to target. 126 * @param sufficientStats The sufficient statistics. 127 * @return The MAE of the predictions. 128 */ 129 public static double mae(Regressor variable, RegressionSufficientStatistics sufficientStats) { 130 String varname = variable.getNames()[0]; 131 double sumAbsErr = sufficientStats.sumAbsoluteError.get(varname).doubleValue(); 132 return sumAbsErr / sufficientStats.weightSum; 133 } 134 135 /** 136 * Calculates the explained variance based on the supplied statistics. 137 * @param target The regression dimension or average to target. 138 * @param sufficientStats The sufficient statistics. 139 * @return The explained variance of the truth given the predictions. 140 */ 141 public static double explainedVariance(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) { 142 return compute(target, sufficientStats, RegressionMetrics::explainedVariance); 143 } 144 145 /** 146 * Calculates the explained variance based on the supplied statistics for a single dimension. 147 * @param variable The regression dimension to target. 148 * @param sufficientStats The sufficient statistics. 149 * @return The explained variance of the truth given the predictions. 150 */ 151 public static double explainedVariance(Regressor variable, RegressionSufficientStatistics sufficientStats) { 152 String varname = variable.getNames()[0]; 153 double[] trueArray = sufficientStats.trueValues.get(varname); 154 double[] predictedArray = sufficientStats.predictedValues.get(varname); 155 156 double meanDifference = 0.0; 157 for (int i = 0; i < sufficientStats.n; i++) { 158 meanDifference += sufficientStats.weights[i] * (trueArray[i] - predictedArray[i]); 159 } 160 meanDifference /= sufficientStats.weightSum; 161 double meanTruth = Util.weightedMean(trueArray, sufficientStats.weights, sufficientStats.n); 162 163 double numerator = 0d; 164 double denominator = 0d; 165 for (int i = 0; i < sufficientStats.n; i++) { 166 float weight = sufficientStats.weights[i]; 167 double variance = trueArray[i] - predictedArray[i] - meanDifference; 168 numerator += weight * variance * variance; 169 double difference = trueArray[i] - meanTruth; 170 denominator += weight * difference * difference; 171 } 172 173 return 1d - (numerator/denominator); 174 } 175 176 /** 177 * Computes the supplied function on the supplied metric target. 178 * @param target The metric target. 179 * @param sufficientStats The sufficient statistics. 180 * @param impl The function to apply. 181 * @return The metric value. 182 */ 183 private static double compute(MetricTarget<Regressor> target, 184 RegressionSufficientStatistics sufficientStats, 185 BiFunction<Regressor, RegressionSufficientStatistics, Double> impl) { 186 if (target.getOutputTarget().isPresent()) { 187 return impl.apply(target.getOutputTarget().get(), sufficientStats); 188 } else if (target.getAverageTarget().isPresent()) { 189 EvaluationMetric.Average averageType = target.getAverageTarget().get(); 190 switch (averageType) { 191 case MACRO: 192 double accumulator = 0.0; 193 for (Regressor r : sufficientStats.domain.getDomain()) { 194 accumulator += impl.apply(r,sufficientStats); 195 } 196 return accumulator / sufficientStats.domain.size(); 197 case MICRO: 198 throw new IllegalStateException("Micro averages are not supported for regression metrics."); 199 default: 200 throw new IllegalStateException("Unexpected average type " + averageType); 201 } 202 } else { 203 throw new IllegalStateException("MetricTarget without target."); 204 } 205 } 206 207}