/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree.impurity;

import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;

import java.util.List;

/**
 * Computes the weighted mean absolute error of a collection of regression targets.
 * <p>
 * This is the impurity measure applied to a regression node: the weighted mean of the
 * targets is found first, then the weighted absolute deviations from that mean are summed.
 * <p>
 * All accumulation is carried out in single precision ({@code float}).
 */
public class MeanAbsoluteError implements RegressorImpurity {

    /**
     * Returns the weighted mean absolute deviation of every element of {@code targets}.
     * <p>
     * The weighted absolute deviation sum is divided by the weight sum in {@code float}
     * arithmetic before being widened to {@code double}. If the weights sum to zero
     * (including when {@code targets} is empty) the division yields {@code NaN}.
     *
     * @param targets The target values.
     * @param weights The weight for each target, read at the same positions as {@code targets}.
     * @return The weighted absolute deviation sum divided by the weight sum.
     */
    @Override
    public double impurity(float[] targets, float[] weights) {
        final int count = targets.length;

        // Pass one: accumulate the pieces of the weighted mean.
        float targetTotal = 0.0f;
        float weightTotal = 0.0f;
        for (int pos = 0; pos < count; pos++) {
            targetTotal += targets[pos] * weights[pos];
            weightTotal += weights[pos];
        }
        final float centre = targetTotal / weightTotal;

        // Pass two: accumulate weighted absolute deviations from that mean.
        float deviationTotal = 0.0f;
        for (int pos = 0; pos < count; pos++) {
            deviationTotal += Math.abs(centre - targets[pos]) * weights[pos];
        }

        // Deliberately a float division; the result is widened on return.
        return deviationTotal / weightTotal;
    }

    /**
     * Builds an {@link ImpurityTuple} from the first {@code indicesLength} entries of
     * {@code indices}.
     * <p>
     * The tuple is constructed with the weighted absolute deviation sum (not divided by the
     * weight sum) and the weight sum. When exactly one index is supplied the deviation is
     * zero by definition, so the tuple is built directly from that element's weight.
     *
     * @param indices Positions into {@code targets} and {@code weights}.
     * @param indicesLength How many leading entries of {@code indices} to use.
     * @param targets The target values.
     * @param weights The per-element weights.
     * @return A tuple constructed from the weighted absolute deviation sum and the weight sum.
     */
    @Override
    public ImpurityTuple impurityTuple(int[] indices, int indicesLength, float[] targets, float[] weights) {
        // A single element sits exactly on its own mean.
        if (indicesLength == 1) {
            return new ImpurityTuple(0.0f, weights[indices[0]]);
        }

        float targetTotal = 0.0f;
        float weightTotal = 0.0f;
        for (int pos = 0; pos < indicesLength; pos++) {
            final int element = indices[pos];
            targetTotal += targets[element] * weights[element];
            weightTotal += weights[element];
        }
        final float centre = targetTotal / weightTotal;

        float deviationTotal = 0.0f;
        for (int pos = 0; pos < indicesLength; pos++) {
            final int element = indices[pos];
            deviationTotal += Math.abs(centre - targets[element]) * weights[element];
        }

        return new ImpurityTuple(deviationTotal, weightTotal);
    }

    /**
     * Builds an {@link ImpurityTuple} over the union of several index arrays.
     * <p>
     * Every entry of every array in {@code indices} is visited, in list order, once to
     * compute the weighted mean and once to sum the weighted absolute deviations from it.
     * The tuple is constructed with that deviation sum (not divided by the weight sum) and
     * the weight sum.
     *
     * @param indices Arrays of positions into {@code targets} and {@code weights}.
     * @param targets The target values.
     * @param weights The per-element weights.
     * @return A tuple constructed from the weighted absolute deviation sum and the weight sum.
     */
    @Override
    public ImpurityTuple impurityTuple(List<int[]> indices, float[] targets, float[] weights) {
        float targetTotal = 0.0f;
        float weightTotal = 0.0f;
        for (int[] group : indices) {
            for (int element : group) {
                targetTotal += targets[element] * weights[element];
                weightTotal += weights[element];
            }
        }
        final float centre = targetTotal / weightTotal;

        float deviationTotal = 0.0f;
        for (int[] group : indices) {
            for (int element : group) {
                deviationTotal += Math.abs(centre - targets[element]) * weights[element];
            }
        }

        return new ImpurityTuple(deviationTotal, weightTotal);
    }

    /**
     * Returns the fixed name of this impurity measure.
     *
     * @return The string {@code "MeanAbsoluteError"}.
     */
    @Override
    public String toString() {
        return "MeanAbsoluteError";
    }

    /**
     * Creates the provenance for this object.
     *
     * @return A new {@link ConfiguredObjectProvenanceImpl} constructed from this instance
     *         and the string {@code "RegressorImpurity"}.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "RegressorImpurity");
    }
}