001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree.impurity;
018
019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
021
022import java.util.List;
023
024/**
025 * Measures the mean squared error over a set of inputs.
026 * <p>
027 * Used to calculate the impurity of a regression node.
028 */
029public class MeanSquaredError implements RegressorImpurity {
030
031    @Override
032    public double impurity(float[] targets, float[] weights) {
033        float weightedSum = 0.0f;
034        float weightSum = 0.0f;
035        for (int i = 0; i < targets.length; i++) {
036            weightedSum += targets[i]*weights[i];
037            weightSum += weights[i];
038        }
039        float mean = weightedSum / weightSum;
040
041        float squaredError = 0.0f;
042
043        for (int i = 0; i < targets.length; i++) {
044            float error = mean - targets[i];
045            squaredError += error*error*weights[i];
046        }
047        return squaredError / weightSum;
048    }
049
050    @Override
051    public ImpurityTuple impurityTuple(int[] indices, int indicesLength, float[] targets, float[] weights) {
052        if (indicesLength == 1) {
053            return new ImpurityTuple(0.0f,weights[indices[0]]);
054        } else {
055            float weightedSquaredSum = 0.0f;
056            float weightedSum = 0.0f;
057            float weightSum = 0.0f;
058            for (int i = 0; i < indicesLength; i++) {
059                int idx = indices[i];
060                float weight = weights[idx];
061                float target = targets[idx];
062                float curVal = target * weight;
063                weightedSum += curVal;
064                weightedSquaredSum += curVal * target;
065                weightSum += weight;
066            }
067            float mean = weightedSum / weightSum;
068            return new ImpurityTuple((weightedSquaredSum / weightSum) - (mean*mean),weightSum);
069        }
070    }
071
072    @Override
073    public ImpurityTuple impurityTuple(List<int[]> indices, float[] targets, float[] weights) {
074        float weightedSquaredSum = 0.0f;
075        float weightedSum = 0.0f;
076        float weightSum = 0.0f;
077        for (int[] curIndices : indices) {
078            for (int i = 0; i < curIndices.length; i++) {
079                int idx = curIndices[i];
080                float weight = weights[idx];
081                float target = targets[idx];
082                float curVal = target * weight;
083                weightedSum += curVal;
084                weightedSquaredSum += curVal * target;
085                weightSum += weight;
086            }
087        }
088
089        float mean = weightedSum / weightSum;
090        return new ImpurityTuple((weightedSquaredSum / weightSum) - (mean*mean),weightSum);
091    }
092
093    @Override
094    public String toString() {
095        return "MeanSquaredError";
096    }
097
098    @Override
099    public ConfiguredObjectProvenance getProvenance() {
100        return new ConfiguredObjectProvenanceImpl(this,"RegressorImpurity");
101    }
102}