/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.evaluation;

import com.oracle.labs.mlrg.olcut.util.MutableDouble;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * The sufficient statistics for regression metrics (i.e., each prediction and each true value).
 * <p>
 * For every dimension name found in the supplied domain this class stores the
 * weighted sum of absolute errors, the weighted sum of squared errors, and the
 * raw predicted and true values indexed by the position of the prediction in
 * the input list. All statistics are computed once, in the constructor.
 */
public final class RegressionSufficientStatistics {

    /** The number of predictions supplied to the constructor. */
    final int n;

    /** The output domain supplied to the constructor. */
    final ImmutableOutputInfo<Regressor> domain;

    /** Weighted sum of {@code |truth - prediction|}, keyed by dimension name. */
    final Map<String, MutableDouble> sumAbsoluteError = new LinkedHashMap<>();

    /** Weighted sum of {@code (truth - prediction)^2}, keyed by dimension name. */
    final Map<String, MutableDouble> sumSquaredError = new LinkedHashMap<>();

    /** Predicted values keyed by dimension name; each array has length {@link #n}. */
    final Map<String, double[]> predictedValues = new LinkedHashMap<>();

    /** Ground truth values keyed by dimension name; each array has length {@link #n}. */
    final Map<String, double[]> trueValues = new LinkedHashMap<>();

    // if useExampleWeights is false, all weights are set to 1.0
    final float[] weights;

    // if useExampleWeights is false, weightSum == n (rounded to the nearest float)
    final float weightSum;

    /**
     * Computes the sufficient statistics for the supplied predictions.
     * <p>
     * One entry is created in each statistics map for every element of
     * {@code domain.getDomain()}, keyed by the first name of that element.
     *
     * @param domain The output domain which defines the regression dimensions.
     * @param predictions The predictions to tabulate; each must carry its ground truth in its example.
     * @param useExampleWeights If true, weight each prediction by its example's weight, otherwise use 1.0.
     * @throws IllegalArgumentException If a ground truth or a predicted output is the
     *         unknown regressor sentinel, if a ground truth dimension is not part of
     *         {@code domain}, or if a prediction has fewer values than its ground truth.
     */
    public RegressionSufficientStatistics(ImmutableOutputInfo<Regressor> domain, List<Prediction<Regressor>> predictions, boolean useExampleWeights) {
        this.domain = domain;
        this.n = predictions.size();
        this.weights = initWeights(predictions, useExampleWeights);
        for (Regressor e : domain.getDomain()) {
            String name = e.getNames()[0];
            sumAbsoluteError.put(name,new MutableDouble());
            sumSquaredError.put(name,new MutableDouble());
            predictedValues.put(name,new double[this.n]);
            trueValues.put(name,new double[this.n]);
        }
        // Must run after the maps and weights are populated, as tabulate reads both.
        this.weightSum = tabulate(predictions);
    }

    /**
     * Fills the per-dimension statistics from the predictions and sums the weights.
     * <p>
     * Predicted values are matched to ground truth values by index.
     * NOTE(review): this assumes both regressors order their dimensions
     * identically - confirm against {@code Regressor}.
     *
     * @param predictions The predictions to tabulate.
     * @return The sum of the per-prediction weights.
     * @throws IllegalArgumentException If an output is the unknown regressor sentinel,
     *         if a ground truth dimension is missing from the domain, or if a
     *         prediction has fewer values than its ground truth.
     */
    private float tabulate(List<Prediction<Regressor>> predictions) {
        // Accumulate in double: a float accumulator stops growing once it reaches
        // 2^24 when adding 1.0f, which would silently under-count large datasets.
        double runningWeight = 0.0;

        for (int i = 0; i < this.n; i++) {
            Prediction<Regressor> prediction = predictions.get(i);

            float weight = weights[i];
            runningWeight += weight;

            Regressor pred = prediction.getOutput();
            Regressor truth = prediction.getExample().getOutput();
            if (truth.equals(RegressionFactory.UNKNOWN_REGRESSOR)) {
                throw new IllegalArgumentException("The sentinel Unknown Regressor was used as a ground truth output at prediction number " + i);
            } else if (pred.equals(RegressionFactory.UNKNOWN_REGRESSOR)) {
                throw new IllegalArgumentException("The sentinel Unknown Regressor was predicted by the model at prediction number " + i);
            }

            // Fetch the arrays once per prediction rather than once per dimension.
            int numDimensions = truth.size();
            String[] dimensionNames = truth.getNames();
            double[] truthValues = truth.getValues();
            double[] predValues = pred.getValues();
            if (predValues.length < numDimensions) {
                throw new IllegalArgumentException("The model predicted " + predValues.length + " dimensions but the ground truth has " + numDimensions + " at prediction number " + i);
            }

            for (int j = 0; j < numDimensions; j++) {
                String name = dimensionNames[j];
                // All four maps share the same key set (populated together in the
                // constructor), so a single lookup is enough to validate the name.
                MutableDouble absoluteError = sumAbsoluteError.get(name);
                if (absoluteError == null) {
                    throw new IllegalArgumentException("The ground truth dimension '" + name + "' is not present in the output domain at prediction number " + i);
                }
                double trueValue = truthValues[j];
                double predValue = predValues[j];

                double diff = trueValue - predValue;
                absoluteError.increment(weight*Math.abs(diff));
                sumSquaredError.get(name).increment(weight*diff*diff);

                trueValues.get(name)[i] = trueValue;
                predictedValues.get(name)[i] = predValue;
            }
        }
        return (float) runningWeight;
    }

    /**
     * Builds the per-prediction weight array.
     *
     * @param predictions The predictions whose example weights may be read.
     * @param useExampleWeights If true, copy each example's weight, otherwise fill with 1.0.
     * @return An array with one weight per prediction, in list order.
     */
    private static float[] initWeights(List<Prediction<Regressor>> predictions, boolean useExampleWeights) {
        float[] weights = new float[predictions.size()];
        if (useExampleWeights) {
            for (int i = 0; i < predictions.size(); i++) {
                weights[i] = predictions.get(i).getExample().getWeight();
            }
        } else {
            Arrays.fill(weights,1.0f);
        }
        return weights;
    }

}