001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.rtree.impurity; 018 019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 021 022import java.util.List; 023 024/** 025 * Measures the mean squared error over a set of inputs. 026 * <p> 027 * Used to calculate the impurity of a regression node. 028 */ 029public class MeanSquaredError implements RegressorImpurity { 030 031 @Override 032 public double impurity(float[] targets, float[] weights) { 033 float weightedSum = 0.0f; 034 float weightSum = 0.0f; 035 for (int i = 0; i < targets.length; i++) { 036 weightedSum += targets[i]*weights[i]; 037 weightSum += weights[i]; 038 } 039 float mean = weightedSum / weightSum; 040 041 float squaredError = 0.0f; 042 043 for (int i = 0; i < targets.length; i++) { 044 float error = mean - targets[i]; 045 squaredError += error*error*weights[i]; 046 } 047 return squaredError / weightSum; 048 } 049 050 @Override 051 public ImpurityTuple impurityTuple(int[] indices, int indicesLength, float[] targets, float[] weights) { 052 if (indicesLength == 1) { 053 return new ImpurityTuple(0.0f,weights[indices[0]]); 054 } else { 055 float weightedSquaredSum = 0.0f; 056 float weightedSum = 0.0f; 057 float weightSum = 0.0f; 058 for (int i = 0; i < indicesLength; i++) { 059 int idx = indices[i]; 060 float weight = weights[idx]; 061 float target = targets[idx]; 062 float curVal = target * weight; 063 weightedSum += curVal; 064 weightedSquaredSum += curVal * target; 065 weightSum += weight; 066 } 067 float mean = weightedSum / weightSum; 068 return new ImpurityTuple((weightedSquaredSum / weightSum) - (mean*mean),weightSum); 069 } 070 } 071 072 @Override 073 public ImpurityTuple impurityTuple(List<int[]> indices, float[] targets, float[] weights) { 074 float weightedSquaredSum = 0.0f; 075 float weightedSum = 0.0f; 076 float weightSum = 0.0f; 077 for (int[] curIndices : indices) { 078 for (int i = 0; i < curIndices.length; i++) { 079 int idx = curIndices[i]; 080 float weight = weights[idx]; 081 float target = targets[idx]; 082 float curVal = target * weight; 083 weightedSum += curVal; 084 weightedSquaredSum += curVal * target; 085 weightSum += weight; 086 } 087 } 088 089 float mean = weightedSum / weightSum; 090 return new ImpurityTuple((weightedSquaredSum / weightSum) - (mean*mean),weightSum); 091 } 092 093 @Override 094 public String toString() { 095 return "MeanSquaredError"; 096 } 097 098 @Override 099 public ConfiguredObjectProvenance getProvenance() { 100 return new ConfiguredObjectProvenanceImpl(this,"RegressorImpurity"); 101 } 102}