001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.evaluation;
018
019import com.oracle.labs.mlrg.olcut.util.MutableDouble;
020import org.tribuo.ImmutableOutputInfo;
021import org.tribuo.Prediction;
022import org.tribuo.regression.RegressionFactory;
023import org.tribuo.regression.Regressor;
024
025import java.util.Arrays;
026import java.util.LinkedHashMap;
027import java.util.List;
028import java.util.Map;
029
030/**
031 * The sufficient statistics for regression metrics (i.e., each prediction and each true value).
032 */
033public final class RegressionSufficientStatistics {
034
035    final int n;
036    final ImmutableOutputInfo<Regressor> domain;
037
038    final Map<String, MutableDouble> sumAbsoluteError = new LinkedHashMap<>();
039    final Map<String, MutableDouble> sumSquaredError = new LinkedHashMap<>();
040
041    final Map<String, double[]> predictedValues = new LinkedHashMap<>();
042    final Map<String, double[]> trueValues = new LinkedHashMap<>();
043
044    // if useExampleWeights is false, all weights are set to 1.0
045    final float[] weights;
046
047    // if useExampleWeights is false, weightSum == n
048    final float weightSum;
049
050    public RegressionSufficientStatistics(ImmutableOutputInfo<Regressor> domain, List<Prediction<Regressor>> predictions, boolean useExampleWeights) {
051        this.domain = domain;
052        this.n = predictions.size();
053        this.weights = initWeights(predictions, useExampleWeights);
054        for (Regressor e : domain.getDomain()) {
055            String name = e.getNames()[0];
056            sumAbsoluteError.put(name,new MutableDouble());
057            sumSquaredError.put(name,new MutableDouble());
058            predictedValues.put(name,new double[this.n]);
059            trueValues.put(name,new double[this.n]);
060        }
061        this.weightSum = tabulate(predictions);
062    }
063
064    private float tabulate(List<Prediction<Regressor>> predictions) {
065        float weightSum = 0f;
066
067        for (int i = 0; i < this.n; i++) {
068            Prediction<Regressor> prediction = predictions.get(i);
069
070            float weight = weights[i];
071            weightSum += weight;
072
073            Regressor pred = prediction.getOutput();
074            Regressor truth = prediction.getExample().getOutput();
075            if (truth.equals(RegressionFactory.UNKNOWN_REGRESSOR)) {
076                throw new IllegalArgumentException("The sentinel Unknown Regressor was used as a ground truth output at prediction number " + i);
077            } else if (pred.equals(RegressionFactory.UNKNOWN_REGRESSOR)) {
078                throw new IllegalArgumentException("The sentinel Unknown Regressor was predicted by the model at prediction number " + i);
079            }
080
081            for (int j = 0; j < truth.size(); j++) {
082                String name = truth.getNames()[j];
083                double trueValue = truth.getValues()[j];
084                double predValue = pred.getValues()[j];
085
086                double diff = trueValue - predValue;
087                sumAbsoluteError.get(name).increment(weight*Math.abs(diff));
088                sumSquaredError.get(name).increment(weight*diff*diff);
089
090                trueValues.get(name)[i] = trueValue;
091                predictedValues.get(name)[i] = predValue;
092            }
093        }
094        return weightSum;
095    }
096
097    private static float[] initWeights(List<Prediction<Regressor>> predictions, boolean useExampleWeights) {
098        float[] weights = new float[predictions.size()];
099        if (useExampleWeights) {
100            for (int i = 0; i < predictions.size(); i++) {
101                weights[i] = predictions.get(i).getExample().getWeight();
102            }
103        } else {
104            Arrays.fill(weights,1.0f);
105        }
106        return weights;
107    }
108
109}