001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.evaluation;
018
019import org.tribuo.evaluation.metrics.EvaluationMetric;
020import org.tribuo.evaluation.metrics.MetricTarget;
021import org.tribuo.regression.Regressor;
022import org.tribuo.util.Util;
023
024import java.util.function.BiFunction;
025import java.util.function.ToDoubleBiFunction;
026
027/**
028 * An enum of the default {@link RegressionMetric}s supported by the multi-dimensional regression
029 * evaluation package.
030 * <p>
031 * The metrics treat each regressed dimension independently.
032 */
033public enum RegressionMetrics {
034
035    /**
036     * Calculates the R^2 of the predictions.
037     */
038    R2((target, context) -> RegressionMetrics.r2(target, context.getMemo())),
039    /**
040     * Calculates the Root Mean Squared Error of the predictions.
041     */
042    RMSE((target, context) -> RegressionMetrics.rmse(target, context.getMemo())),
043    /**
044     * Calculates the Mean Absolute Error of the predictions.
045     */
046    MAE((target, context) -> RegressionMetrics.mae(target, context.getMemo())),
047    /**
048     * Calculates the Explained Variance of the predictions.
049     */
050    EV((target, context) -> RegressionMetrics.explainedVariance(target, context.getMemo()));
051
052    private final ToDoubleBiFunction<MetricTarget<Regressor>, RegressionMetric.Context> impl;
053    RegressionMetrics(ToDoubleBiFunction<MetricTarget<Regressor>, RegressionMetric.Context> impl) {
054        this.impl = impl;
055    }
056
057    RegressionMetric forTarget(MetricTarget<Regressor> target) {
058        return new RegressionMetric(target, this.name(), this.impl);
059    }
060
061    /**
062     * Calculates R^2 based on the supplied statistics.
063     * @param target The regression dimension or average to target.
064     * @param sufficientStats The sufficient statistics.
065     * @return The R^2 value of the predictions.
066     */
067    public static double r2(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) {
068        return compute(target, sufficientStats, RegressionMetrics::r2);
069    }
070
071    /**
072     * Calculates R^2 based on the supplied statistics for a single dimension.
073     * @param variable The regression dimension.
074     * @param sufficientStats The sufficient statistics.
075     * @return The R^2 value of the predictions.
076     */
077    public static double r2(Regressor variable, RegressionSufficientStatistics sufficientStats) {
078        String varname = variable.getNames()[0];
079        double[] trueArray = sufficientStats.trueValues.get(varname);
080        double numerator = sufficientStats.sumSquaredError.get(varname).doubleValue();
081        double meanTruth = Util.weightedMean(trueArray, sufficientStats.weights, sufficientStats.n);
082        double denominator = 0.0;
083        for (int i = 0; i < sufficientStats.n; i++) {
084            double difference = trueArray[i] - meanTruth;
085            float currWeight = sufficientStats.weights[i];
086            denominator += currWeight * difference * difference;
087        }
088        return 1.0 - (numerator / denominator);
089    }
090
091    /**
092     * Calculates the RMSE based on the supplied statistics.
093     * @param target The regression dimension or average to target.
094     * @param sufficientStats The sufficient statistics.
095     * @return The RMSE of the predictions.
096     */
097    public static double rmse(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) {
098        return compute(target, sufficientStats, RegressionMetrics::rmse);
099    }
100
101    /**
102     * Calculates the RMSE based on the supplied statistics for a single dimension.
103     * @param variable The regression dimension to target.
104     * @param sufficientStats The sufficient statistics.
105     * @return The RMSE of the predictions.
106     */
107    public static double rmse(Regressor variable, RegressionSufficientStatistics sufficientStats) {
108        String varname = variable.getNames()[0];
109        double sumSqErr = sufficientStats.sumSquaredError.get(varname).doubleValue();
110        return Math.sqrt(sumSqErr / sufficientStats.weightSum);
111    }
112
113    /**
114     * Calculates the Mean Absolute Error based on the supplied statistics.
115     * @param target The regression dimension or average to target.
116     * @param sufficientStats The sufficient statistics.
117     * @return The MAE of the predictions.
118     */
119    public static double mae(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) {
120        return compute(target, sufficientStats, RegressionMetrics::mae);
121    }
122
123    /**
124     * Calculates the Mean Absolute Error based on the supplied statistics for a single dimension.
125     * @param variable The regression dimension to target.
126     * @param sufficientStats The sufficient statistics.
127     * @return The MAE of the predictions.
128     */
129    public static double mae(Regressor variable, RegressionSufficientStatistics sufficientStats) {
130        String varname = variable.getNames()[0];
131        double sumAbsErr = sufficientStats.sumAbsoluteError.get(varname).doubleValue();
132        return sumAbsErr / sufficientStats.weightSum;
133    }
134
135    /**
136     * Calculates the explained variance based on the supplied statistics.
137     * @param target The regression dimension or average to target.
138     * @param sufficientStats The sufficient statistics.
139     * @return The explained variance of the truth given the predictions.
140     */
141    public static double explainedVariance(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) {
142        return compute(target, sufficientStats, RegressionMetrics::explainedVariance);
143    }
144
145    /**
146     * Calculates the explained variance based on the supplied statistics for a single dimension.
147     * @param variable The regression dimension to target.
148     * @param sufficientStats The sufficient statistics.
149     * @return The explained variance of the truth given the predictions.
150     */
151    public static double explainedVariance(Regressor variable, RegressionSufficientStatistics sufficientStats) {
152        String varname = variable.getNames()[0];
153        double[] trueArray = sufficientStats.trueValues.get(varname);
154        double[] predictedArray = sufficientStats.predictedValues.get(varname);
155
156        double meanDifference = 0.0;
157        for (int i = 0; i < sufficientStats.n; i++) {
158            meanDifference += sufficientStats.weights[i] * (trueArray[i] - predictedArray[i]);
159        }
160        meanDifference /= sufficientStats.weightSum;
161        double meanTruth = Util.weightedMean(trueArray, sufficientStats.weights, sufficientStats.n);
162
163        double numerator = 0d;
164        double denominator = 0d;
165        for (int i = 0; i < sufficientStats.n; i++) {
166            float weight = sufficientStats.weights[i];
167            double variance = trueArray[i] - predictedArray[i] - meanDifference;
168            numerator += weight * variance * variance;
169            double difference = trueArray[i] - meanTruth;
170            denominator += weight * difference * difference;
171        }
172
173        return 1d - (numerator/denominator);
174    }
175
176    /**
177     * Computes the supplied function on the supplied metric target.
178     * @param target The metric target.
179     * @param sufficientStats The sufficient statistics.
180     * @param impl The function to apply.
181     * @return The metric value.
182     */
183    private static double compute(MetricTarget<Regressor> target,
184                                  RegressionSufficientStatistics sufficientStats,
185                                  BiFunction<Regressor, RegressionSufficientStatistics, Double> impl) {
186        if (target.getOutputTarget().isPresent()) {
187            return impl.apply(target.getOutputTarget().get(), sufficientStats);
188        } else if (target.getAverageTarget().isPresent()) {
189            EvaluationMetric.Average averageType = target.getAverageTarget().get();
190            switch (averageType) {
191                case MACRO:
192                    double accumulator = 0.0;
193                    for (Regressor r : sufficientStats.domain.getDomain()) {
194                        accumulator += impl.apply(r,sufficientStats);
195                    }
196                    return accumulator / sufficientStats.domain.size();
197                case MICRO:
198                    throw new IllegalStateException("Micro averages are not supported for regression metrics.");
199                default:
200                    throw new IllegalStateException("Unexpected average type " + averageType);
201            }
202        } else {
203            throw new IllegalStateException("MetricTarget without target.");
204        }
205    }
206
207}