001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.dtree.impurity;
018
019import com.oracle.labs.mlrg.olcut.config.Configurable;
020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
022
023import java.util.Map;
024
025/**
026 * Calculates a tree impurity score based on label counts, weighted label counts or a probability distribution.
027 */
028public interface LabelImpurity extends Configurable, Provenancable<ConfiguredObjectProvenance> {
029
030    /**
031     * Calculates the impurity, assuming it's input is a normalized probability distribution.
032     * @param input The input probability distribution.
033     * @return The impurity.
034     */
035    public double impurityNormed(double[] input);
036
037    /**
038     * Calculates the impurity assuming the inputs are weighted counts normalizing by their sum.
039     * @param input The input counts.
040     * @return The impurity.
041     */
042    default public double impurityWeighted(double[] input) {
043        double[] prob = new double[input.length];
044
045        double sum = 0.0;
046        for (double i : input) {
047            sum += i;
048        }
049
050        for (int i = 0; i < input.length; i++) {
051            prob[i] = input[i] / sum;
052        }
053
054        return sum*impurityNormed(prob);
055    }
056
057    /**
058     * Calculates the impurity assuming the inputs are counts.
059     * @param input The input counts.
060     * @return The impurity.
061     */
062    default public double impurity(double[] input) {
063        double[] prob = new double[input.length];
064
065        double sum = 0.0;
066        for (double i : input) {
067            sum += i;
068        }
069
070        for (int i = 0; i < input.length; i++) {
071            prob[i] = input[i] / sum;
072        }
073
074        return impurityNormed(prob);
075    }
076
077    /**
078     * Calculates the impurity assuming the input are weighted counts, normalizing by their sum.
079     * @param input The input counts.
080     * @return The impurity.
081     */
082    default public double impurityWeighted(float[] input) {
083        double[] prob = new double[input.length];
084
085        double sum = 0.0;
086        for (int i = 0; i < input.length; i++) {
087            float f = input[i];
088            sum += f;
089        }
090
091        for (int i = 0; i < input.length; i++) {
092            prob[i] = input[i] / sum;
093        }
094
095        return sum*impurityNormed(prob);
096    }
097
098    /**
099     * Calculates the impurity assuming the input are fractional counts.
100     * @param input The input counts.
101     * @return The impurity.
102     */
103    default public double impurity(float[] input) {
104        double[] prob = new double[input.length];
105
106        double sum = 0.0;
107        for (double i : input) {
108            sum += i;
109        }
110
111        for (int i = 0; i < input.length; i++) {
112            prob[i] = input[i] / sum;
113        }
114
115        return impurityNormed(prob);
116    }
117
118    /**
119     * Calculates the impurity assuming the input are counts.
120     * @param input The input counts.
121     * @return The impurity.
122     */
123    default public double impurity(int[] input) {
124        double[] prob = new double[input.length];
125
126        int sum = 0;
127        for (int i : input) {
128            sum += i;
129        }
130
131        double sumFloat = sum;
132        for (int i = 0; i < input.length; i++) {
133            prob[i] = input[i] / sumFloat;
134        }
135
136        return impurityNormed(prob);
137    }
138
139    /**
140     * Takes a {@link Map} for weighted counts. Normalizes into a probability distribution before calling impurityNormed(double[]).
141     * @param counts A map from instances to weighted counts.
142     * @return The impurity score.
143     */
144    default public double impurity(Map<String,Double> counts) {
145        double[] prob = new double[counts.size()];
146
147        double sum = 0.0;
148        int i = 0;
149        for (Map.Entry<String,Double> e : counts.entrySet()) {
150            sum += e.getValue();
151            prob[i] = e.getValue();
152            i++;
153        }
154
155        for (int j = 0; j < prob.length; j++) {
156            prob[j] /= sum;
157        }
158
159        return impurityNormed(prob);
160    }
161
162}