/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.dtree.impurity;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;

import java.util.Map;

/**
 * Calculates a tree impurity score based on label counts, weighted label counts or a probability distribution.
 * <p>
 * Every default method normalizes its input by dividing each element by the sum of all elements and
 * then delegates to {@link #impurityNormed(double[])}. None of them guards against a zero sum: in that
 * case the divisions produce non-finite values (NaN or infinity) which are passed on unchanged.
 */
public interface LabelImpurity extends Configurable, Provenancable<ConfiguredObjectProvenance> {

    /**
     * Calculates the impurity, assuming its input is a normalized probability distribution.
     * @param input The input probability distribution.
     * @return The impurity.
     */
    public double impurityNormed(double[] input);

    /**
     * Calculates the impurity assuming the inputs are weighted counts, normalizing by their sum.
     * <p>
     * The returned value is the impurity of the normalized distribution multiplied by the sum of the weights.
     * @param input The input weighted counts.
     * @return The impurity of the normalized counts, scaled by their sum.
     */
    public default double impurityWeighted(double[] input) {
        double[] prob = new double[input.length];

        double sum = 0.0;
        for (double weight : input) {
            sum += weight;
        }

        for (int i = 0; i < input.length; i++) {
            prob[i] = input[i] / sum;
        }

        return sum * impurityNormed(prob);
    }

    /**
     * Calculates the impurity assuming the inputs are counts.
     * @param input The input counts.
     * @return The impurity.
     */
    public default double impurity(double[] input) {
        double[] prob = new double[input.length];

        double sum = 0.0;
        for (double count : input) {
            sum += count;
        }

        for (int i = 0; i < input.length; i++) {
            prob[i] = input[i] / sum;
        }

        return impurityNormed(prob);
    }

    /**
     * Calculates the impurity assuming the input is weighted counts, normalizing by their sum.
     * <p>
     * The sum is accumulated in double precision. The returned value is the impurity of the
     * normalized distribution multiplied by the sum of the weights.
     * @param input The input weighted counts.
     * @return The impurity of the normalized counts, scaled by their sum.
     */
    public default double impurityWeighted(float[] input) {
        double[] prob = new double[input.length];

        double sum = 0.0;
        for (float weight : input) {
            sum += weight;
        }

        for (int i = 0; i < input.length; i++) {
            prob[i] = input[i] / sum;
        }

        return sum * impurityNormed(prob);
    }

    /**
     * Calculates the impurity assuming the input is fractional counts.
     * <p>
     * The sum is accumulated in double precision.
     * @param input The input counts.
     * @return The impurity.
     */
    public default double impurity(float[] input) {
        double[] prob = new double[input.length];

        double sum = 0.0;
        for (float count : input) {
            sum += count;
        }

        for (int i = 0; i < input.length; i++) {
            prob[i] = input[i] / sum;
        }

        return impurityNormed(prob);
    }

    /**
     * Calculates the impurity assuming the input is counts.
     * @param input The input counts.
     * @return The impurity.
     */
    public default double impurity(int[] input) {
        double[] prob = new double[input.length];

        // Accumulate in a long: an int total silently wraps around once the counts
        // add up to more than Integer.MAX_VALUE, which would corrupt the normalization.
        long sum = 0L;
        for (int count : input) {
            sum += count;
        }

        double total = sum;
        for (int i = 0; i < input.length; i++) {
            prob[i] = input[i] / total;
        }

        return impurityNormed(prob);
    }

    /**
     * Takes a {@link Map} of weighted counts. Normalizes the values into a probability distribution
     * before calling {@link #impurityNormed(double[])}.
     * <p>
     * Only the map's values are read; the distribution follows the map's iteration order.
     * @param counts A map whose values are the weighted counts.
     * @return The impurity score.
     */
    public default double impurity(Map<String,Double> counts) {
        double[] prob = new double[counts.size()];

        double sum = 0.0;
        int i = 0;
        for (Double boxed : counts.values()) {
            // Unbox once and reuse the primitive for both the total and the array slot.
            double value = boxed;
            sum += value;
            prob[i] = value;
            i++;
        }

        for (int j = 0; j < prob.length; j++) {
            prob[j] /= sum;
        }

        return impurityNormed(prob);
    }

}