/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.util.infotheory;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.util.infotheory.impl.CachedPair;
import org.tribuo.util.infotheory.impl.CachedTriple;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.tribuo.util.infotheory.impl.TripleDistribution;
import org.tribuo.util.infotheory.impl.WeightCountTuple;
import org.tribuo.util.infotheory.impl.WeightedPairDistribution;
import org.tribuo.util.infotheory.impl.WeightedTripleDistribution;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A class of (discrete) weighted information theoretic functions. Gives warnings if
 * there are insufficient samples to estimate the quantities accurately.
 * <p>
 * Defaults to log_2, so returns values in bits.
 * <p>
 * All functions expect that the element types have well defined equals and
 * hashCode, and that equals is consistent with hashCode. The behaviour is undefined
 * if this is not true.
 */
public final class WeightedInformationTheory {
    private static final Logger logger = Logger.getLogger(WeightedInformationTheory.class.getName());

    /** The ratio of samples to states below which a low-samples warning is logged. */
    public static final double SAMPLES_RATIO = 5.0;
    /** The initial size of the count maps. */
    public static final int DEFAULT_MAP_SIZE = 20;
    /** Log base 2. */
    public static final double LOG_2 = Math.log(2);
    /** Log base e, i.e., 1.0. */
    public static final double LOG_E = Math.log(Math.E);

    /**
     * Sets the base of the logarithm used in the information theoretic calculations.
     * For LOG_2 the unit is "bit", for LOG_E the unit is "nat".
     */
    public static double LOG_BASE = LOG_2;

    /**
     * Chooses which variable is the one with associated weights.
     */
    public enum VariableSelector {
        /** The first variable carries the weights. */
        FIRST,
        /** The second variable carries the weights. */
        SECOND,
        /** The third variable carries the weights. */
        THIRD
    }

    /**
     * Private constructor, only has static methods.
     */
    private WeightedInformationTheory() {}
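
    /*
     * A minimal usage sketch for the list-based jointMI below. The data and
     * variable names here are hypothetical, purely for illustration:
     *
     *   List<String> first   = Arrays.asList("a", "b", "a", "b");
     *   List<String> second  = Arrays.asList("x", "x", "y", "y");
     *   List<String> target  = Arrays.asList("t", "f", "t", "f");
     *   List<Double> weights = Arrays.asList(1.0, 0.5, 1.0, 0.5);
     *   double jmi = WeightedInformationTheory.jointMI(first, second, target, weights);
     */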

    /**
     * Calculates the discrete weighted joint mutual information, using
     * histogram probability estimators. Arrays must be the same length.
     * @param <T1> Type contained in the first array.
     * @param <T2> Type contained in the second array.
     * @param <T3> Type contained in the target array.
     * @param first An array of values.
     * @param second Another array of values.
     * @param target Target array of values.
     * @param weights Array of weight values.
     * @return The mutual information I(first,second;target)
     */
    public static <T1,T2,T3> double jointMI(List<T1> first, List<T2> second, List<T3> target, List<Double> weights) {
        if ((first.size() == second.size()) && (first.size() == target.size()) && (first.size() == weights.size())) {
            WeightedTripleDistribution<T1,T2,T3> tripleRV = WeightedTripleDistribution.constructFromLists(first, second, target, weights);

            return jointMI(tripleRV);
        } else {
            throw new IllegalArgumentException("Weighted Joint Mutual Information requires four vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", target.size() = " + target.size() + ", weights.size() = " + weights.size());
        }
    }

    /**
     * Calculates the discrete weighted joint mutual information from a
     * pre-computed weighted triple distribution.
     * @param <T1> Type of the first variable.
     * @param <T2> Type of the second variable.
     * @param <T3> Type of the target variable.
     * @param tripleRV The weighted triple distribution.
     * @return The weighted mutual information I_w(first,second;target).
     */
    public static <T1,T2,T3> double jointMI(WeightedTripleDistribution<T1,T2,T3> tripleRV) {
        Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount = tripleRV.getJointCount();
        Map<CachedPair<T1,T2>,WeightCountTuple> abCount = tripleRV.getABCount();
        Map<T3,WeightCountTuple> cCount = tripleRV.getCCount();

        double vectorLength = tripleRV.count;
        double jmi = 0.0;
        for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) {
            double jointCurCount = e.getValue().count;
            double jointCurWeight = e.getValue().weight;
            double prob = jointCurCount / vectorLength;
            CachedPair<T1,T2> pair = e.getKey().getAB();
            double abCurCount = abCount.get(pair).count;
            double cCurCount = cCount.get(e.getKey().getC()).count;

            jmi += jointCurWeight * prob * Math.log((vectorLength*jointCurCount)/(abCurCount*cCurCount));
        }
        jmi /= LOG_BASE;

        double stateRatio = vectorLength / jointCount.size();
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}", new Object[]{jmi, stateRatio});
        }

        return jmi;
    }

    /**
     * Calculates the discrete weighted joint mutual information, where the
     * weights are specified per state of the variable chosen by the
     * {@link VariableSelector}. States missing from the weight map default to 1.0.
     * @param <T1> Type of the first variable.
     * @param <T2> Type of the second variable.
     * @param <T3> Type of the target variable.
     * @param rv The triple distribution.
     * @param weights A map from states of the selected variable to weights.
     * @param vs Which variable the weights are associated with.
     * @return The weighted mutual information I_w(first,second;target).
     */
    public static <T1,T2,T3> double jointMI(TripleDistribution<T1,T2,T3> rv, Map<?,Double> weights, VariableSelector vs) {
        Double boxedWeight;
        double vecLength = rv.count;
        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
        Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount();
        Map<T3,MutableLong> cCount = rv.getCCount();

        double jmi = 0.0;
        for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) {
            double jointCurCount = e.getValue().doubleValue();
            double prob = jointCurCount / vecLength;
            CachedPair<T1,T2> pair = new CachedPair<>(e.getKey().getA(),e.getKey().getB());
            double abCurCount = abCount.get(pair).doubleValue();
            double cCurCount = cCount.get(e.getKey().getC()).doubleValue();

            double weight = 1.0;
            switch (vs) {
                case FIRST:
                    boxedWeight = weights.get(e.getKey().getA());
                    weight = boxedWeight == null ? 1.0 : boxedWeight;
                    break;
                case SECOND:
                    boxedWeight = weights.get(e.getKey().getB());
                    weight = boxedWeight == null ? 1.0 : boxedWeight;
                    break;
                case THIRD:
                    boxedWeight = weights.get(e.getKey().getC());
                    weight = boxedWeight == null ? 1.0 : boxedWeight;
                    break;
            }

            jmi += weight * prob * Math.log((vecLength*jointCurCount)/(abCurCount*cCurCount));
        }
        jmi /= LOG_BASE;

        double stateRatio = vecLength / jointCount.size();
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", new Object[]{jmi, stateRatio, vecLength, jointCount.size()});
        }

        return jmi;
    }
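
    /*
     * A sketch of the per-state weighting variant above (hypothetical data,
     * and assuming the TripleDistribution.constructFromLists builder from
     * this package): the weights attach to the states of one variable rather
     * than to each sample, and any state missing from the map gets weight 1.0.
     *
     *   TripleDistribution<String,String,String> rv =
     *       TripleDistribution.constructFromLists(first, second, target);
     *   Map<String,Double> stateWeights = new HashMap<>();
     *   stateWeights.put("t", 2.0); // up-weight the "t" state of the target variable
     *   double jmi = WeightedInformationTheory.jointMI(rv, stateWeights, VariableSelector.THIRD);
     */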

    /**
     * Calculates the discrete weighted conditional mutual information, using
     * histogram probability estimators. Arrays must be the same length.
     * @param <T1> Type contained in the first array.
     * @param <T2> Type contained in the second array.
     * @param <T3> Type contained in the condition array.
     * @param first An array of values.
     * @param second Another array of values.
     * @param condition Array to condition upon.
     * @param weights Array of weight values.
     * @return The conditional mutual information I(first;second|condition)
     */
    public static <T1,T2,T3> double conditionalMI(List<T1> first, List<T2> second, List<T3> condition, List<Double> weights) {
        if ((first.size() == second.size()) && (first.size() == condition.size()) && (first.size() == weights.size())) {
            WeightedTripleDistribution<T1,T2,T3> tripleRV = WeightedTripleDistribution.constructFromLists(first, second, condition, weights);

            return conditionalMI(tripleRV);
        } else {
            throw new IllegalArgumentException("Weighted Conditional Mutual Information requires four vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", condition.size() = " + condition.size() + ", weights.size() = " + weights.size());
        }
    }

    /**
     * Calculates the discrete weighted conditional mutual information from a
     * pre-computed weighted triple distribution.
     * @param <T1> Type of the first variable.
     * @param <T2> Type of the second variable.
     * @param <T3> Type of the conditioning variable.
     * @param tripleRV The weighted triple distribution.
     * @return The weighted conditional mutual information I_w(first;second|condition).
     */
    public static <T1,T2,T3> double conditionalMI(WeightedTripleDistribution<T1,T2,T3> tripleRV) {
        Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount = tripleRV.getJointCount();
        Map<CachedPair<T1,T3>,WeightCountTuple> acCount = tripleRV.getACCount();
        Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = tripleRV.getBCCount();
        Map<T3,WeightCountTuple> cCount = tripleRV.getCCount();

        double vectorLength = tripleRV.count;
        double cmi = 0.0;
        for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) {
            double weight = e.getValue().weight;
            double jointCurCount = e.getValue().count;
            double prob = jointCurCount / vectorLength;
            CachedPair<T1,T3> acPair = e.getKey().getAC();
            CachedPair<T2,T3> bcPair = e.getKey().getBC();
            double acCurCount = acCount.get(acPair).count;
            double bcCurCount = bcCount.get(bcPair).count;
            double cCurCount = cCount.get(e.getKey().getC()).count;

            cmi += weight * prob * Math.log((cCurCount*jointCurCount)/(acCurCount*bcCurCount));
        }
        cmi /= LOG_BASE;

        double stateRatio = vectorLength / jointCount.size();
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio});
        }

        return cmi;
    }
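
    /*
     * For reference, the quantity estimated by the conditional MI methods is
     * the weighted analogue of
     *   I(X;Y|Z) = sum_{x,y,z} p(x,y,z) * log( (p(z) * p(x,y,z)) / (p(x,z) * p(y,z)) ),
     * where each term in the sum is additionally scaled by the weight of its
     * joint state (x,y,z). The count form used in the code,
     * (cCount * jointCount) / (acCount * bcCount), is the same ratio with the
     * 1/N normalisations cancelled.
     */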

    /**
     * Calculates the discrete weighted conditional mutual information, where
     * the weights are specified per state of the variable chosen by the
     * {@link VariableSelector}. States missing from the weight map default to 1.0.
     * @param <T1> Type of the first variable.
     * @param <T2> Type of the second variable.
     * @param <T3> Type of the conditioning variable.
     * @param rv The triple distribution.
     * @param weights A map from states of the selected variable to weights.
     * @param vs Which variable the weights are associated with.
     * @return The weighted conditional mutual information I_w(first;second|condition).
     */
    public static <T1,T2,T3> double conditionalMI(TripleDistribution<T1,T2,T3> rv, Map<?,Double> weights, VariableSelector vs) {
        Double boxedWeight;
        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
        Map<CachedPair<T1,T3>,MutableLong> acCount = rv.getACCount();
        Map<CachedPair<T2,T3>,MutableLong> bcCount = rv.getBCCount();
        Map<T3,MutableLong> cCount = rv.getCCount();

        double vectorLength = rv.count;
        double cmi = 0.0;
        for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) {
            double jointCurCount = e.getValue().doubleValue();
            double prob = jointCurCount / vectorLength;
            CachedPair<T1,T3> acPair = new CachedPair<>(e.getKey().getA(), e.getKey().getC());
            CachedPair<T2,T3> bcPair = new CachedPair<>(e.getKey().getB(), e.getKey().getC());
            double acCurCount = acCount.get(acPair).doubleValue();
            double bcCurCount = bcCount.get(bcPair).doubleValue();
            double cCurCount = cCount.get(e.getKey().getC()).doubleValue();

            double weight = 1.0;
            switch (vs) {
                case FIRST:
                    boxedWeight = weights.get(e.getKey().getA());
                    weight = boxedWeight == null ? 1.0 : boxedWeight;
                    break;
                case SECOND:
                    boxedWeight = weights.get(e.getKey().getB());
                    weight = boxedWeight == null ? 1.0 : boxedWeight;
                    break;
                case THIRD:
                    boxedWeight = weights.get(e.getKey().getC());
                    weight = boxedWeight == null ? 1.0 : boxedWeight;
                    break;
            }
            cmi += weight * prob * Math.log((cCurCount * jointCurCount) / (acCurCount * bcCurCount));
        }
        cmi /= LOG_BASE;

        double stateRatio = vectorLength / jointCount.size();
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio});
        }

        return cmi;
    }

    /**
     * Calculates the discrete weighted mutual information, using histogram
     * probability estimators.
     * <p>
     * Arrays must be the same length.
     * @param <T1> Type of the first array.
     * @param <T2> Type of the second array.
     * @param first An array of values.
     * @param second Another array of values.
     * @param weights Array of weight values.
     * @return The mutual information I(first;second)
     */
    public static <T1,T2> double mi(ArrayList<T1> first, ArrayList<T2> second, ArrayList<Double> weights) {
        if ((first.size() == second.size()) && (first.size() == weights.size())) {
            WeightedPairDistribution<T1,T2> countPair = WeightedPairDistribution.constructFromLists(first,second,weights);
            return mi(countPair);
        } else {
            throw new IllegalArgumentException("Weighted Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
        }
    }

    /**
     * Calculates the discrete weighted mutual information from a pre-computed
     * weighted pair distribution.
     * @param <T1> Type of the first variable.
     * @param <T2> Type of the second variable.
     * @param jointDist The weighted pair distribution.
     * @return The weighted mutual information I_w(first;second).
     */
    public static <T1,T2> double mi(WeightedPairDistribution<T1,T2> jointDist) {
        double vectorLength = jointDist.count;
        double mi = 0.0;
        Map<CachedPair<T1,T2>,WeightCountTuple> countDist = jointDist.getJointCounts();
        Map<T1,WeightCountTuple> firstCountDist = jointDist.getFirstCount();
        Map<T2,WeightCountTuple> secondCountDist = jointDist.getSecondCount();

        for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
            double weight = e.getValue().weight;
            double jointCount = e.getValue().count;
            double prob = jointCount / vectorLength;
            double firstCount = firstCountDist.get(e.getKey().getA()).count;
            double secondCount = secondCountDist.get(e.getKey().getB()).count;

            mi += weight * prob * Math.log((vectorLength*jointCount)/(firstCount*secondCount));
        }
        mi /= LOG_BASE;

        double stateRatio = vectorLength / countDist.size();
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio});
        }

        return mi;
    }
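
    /*
     * A minimal sketch of the pairwise case (hypothetical data). Note that
     * the list-based overload takes ArrayLists:
     *
     *   ArrayList<Integer> xs = new ArrayList<>(Arrays.asList(0, 1, 0, 1));
     *   ArrayList<Integer> ys = new ArrayList<>(Arrays.asList(0, 1, 1, 0));
     *   ArrayList<Double>  ws = new ArrayList<>(Arrays.asList(1.0, 1.0, 0.5, 0.5));
     *   double mi = WeightedInformationTheory.mi(xs, ys, ws);
     */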

    /**
     * Calculates the discrete weighted mutual information, where the weights
     * are specified per state of the variable chosen by the
     * {@link VariableSelector}. States missing from the weight map default to
     * 1.0, and only {@link VariableSelector#FIRST} and
     * {@link VariableSelector#SECOND} are valid in a two variable calculation.
     * @param <T1> Type of the first variable.
     * @param <T2> Type of the second variable.
     * @param pairDist The pair distribution.
     * @param weights A map from states of the selected variable to weights.
     * @param vs Which variable the weights are associated with.
     * @return The weighted mutual information I_w(first;second).
     */
    public static <T1,T2> double mi(PairDistribution<T1,T2> pairDist, Map<?,Double> weights, VariableSelector vs) {
        if (vs == VariableSelector.THIRD) {
            throw new IllegalArgumentException("MI only has two variables");
        }
        Map<CachedPair<T1,T2>,MutableLong> countDist = pairDist.jointCounts;
        Map<T1,MutableLong> firstCountDist = pairDist.firstCount;
        Map<T2,MutableLong> secondCountDist = pairDist.secondCount;

        Double boxedWeight;
        double vectorLength = pairDist.count;
        double mi = 0.0;
        boolean error = false;
        for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) {
            double jointCount = e.getValue().doubleValue();
            double prob = jointCount / vectorLength;
            double firstCount = firstCountDist.get(e.getKey().getA()).doubleValue();
            double secondCount = secondCountDist.get(e.getKey().getB()).doubleValue();

            double top = vectorLength * jointCount;
            double bottom = firstCount * secondCount;
            double ratio = top/bottom;
            double logRatio = Math.log(ratio);

            if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(mi)) {
                logger.log(Level.WARNING, "State = " + e.getKey().toString());
                logger.log(Level.WARNING, "mi = " + mi + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio);
                error = true;
            }

            double weight = 1.0;
            switch (vs) {
                case FIRST:
                    boxedWeight = weights.get(e.getKey().getA());
                    weight = boxedWeight == null ? 1.0 : boxedWeight;
                    break;
                case SECOND:
                    boxedWeight = weights.get(e.getKey().getB());
                    weight = boxedWeight == null ? 1.0 : boxedWeight;
                    break;
                default:
                    throw new IllegalArgumentException("VariableSelector.THIRD not allowed in a two variable calculation.");
            }
            mi += weight * prob * logRatio;
        }
        mi /= LOG_BASE;

        double stateRatio = vectorLength / countDist.size();
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio});
        }

        if (error) {
            logger.log(Level.SEVERE, "NaN found ", new IllegalStateException("NaN found"));
        }

        return mi;
    }

    /**
     * Calculates the Shannon/Guiasu weighted joint entropy of two arrays,
     * using histogram probability estimators.
     * <p>
     * Arrays must be the same length.
     * @param <T1> Type of the first array.
     * @param <T2> Type of the second array.
     * @param first An array of values.
     * @param second Another array of values.
     * @param weights Array of weight values.
     * @return The entropy H(first,second)
     */
    public static <T1,T2> double jointEntropy(ArrayList<T1> first, ArrayList<T2> second, ArrayList<Double> weights) {
        if ((first.size() == second.size()) && (first.size() == weights.size())) {
            double vectorLength = first.size();
            double jointEntropy = 0.0;

            WeightedPairDistribution<T1,T2> pairDist = WeightedPairDistribution.constructFromLists(first,second,weights);
            Map<CachedPair<T1,T2>,WeightCountTuple> countDist = pairDist.getJointCounts();

            for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
                double prob = e.getValue().count / vectorLength;
                double weight = e.getValue().weight;

                jointEntropy -= weight * prob * Math.log(prob);
            }
            jointEntropy /= LOG_BASE;

            double stateRatio = vectorLength / countDist.size();
            if (stateRatio < SAMPLES_RATIO) {
                logger.log(Level.INFO, "Weighted Joint Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{jointEntropy, stateRatio});
            }

            return jointEntropy;
        } else {
            throw new IllegalArgumentException("Weighted Joint Entropy requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
        }
    }
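
    /*
     * For reference, the Guiasu weighted entropies computed above and below
     * scale each p log p term by the average weight w observed for that state:
     *   H_w(X,Y) = - sum_{x,y} w(x,y) p(x,y) log p(x,y)
     *   H_w(X|Y) = - sum_{x,y} w(x,y) p(x,y) log( p(x,y) / p(y) )
     * so with uniform unit weights they reduce to the usual Shannon forms.
     */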

    /**
     * Calculates the discrete Shannon/Guiasu weighted conditional entropy of
     * two arrays, using histogram probability estimators.
     * <p>
     * Arrays must be the same length.
     * @param <T1> Type of the first array.
     * @param <T2> Type of the second array.
     * @param vector The main array of values.
     * @param condition The array to condition on.
     * @param weights Array of weight values.
     * @return The weighted conditional entropy H_w(vector|condition).
     */
    public static <T1,T2> double weightedConditionalEntropy(ArrayList<T1> vector, ArrayList<T2> condition, ArrayList<Double> weights) {
        if ((vector.size() == condition.size()) && (vector.size() == weights.size())) {
            double vectorLength = vector.size();
            double condEntropy = 0.0;

            WeightedPairDistribution<T1,T2> pairDist = WeightedPairDistribution.constructFromLists(vector,condition,weights);
            Map<CachedPair<T1,T2>,WeightCountTuple> countDist = pairDist.getJointCounts();
            Map<T2,WeightCountTuple> conditionCountDist = pairDist.getSecondCount();

            for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
                double prob = e.getValue().count / vectorLength;
                double condProb = conditionCountDist.get(e.getKey().getB()).count / vectorLength;
                double weight = e.getValue().weight;

                condEntropy -= weight * prob * Math.log(prob/condProb);
            }
            condEntropy /= LOG_BASE;

            double stateRatio = vectorLength / countDist.size();
            if (stateRatio < SAMPLES_RATIO) {
                logger.log(Level.INFO, "Weighted Conditional Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{condEntropy, stateRatio});
            }

            return condEntropy;
        } else {
            throw new IllegalArgumentException("Weighted Conditional Entropy requires three vectors the same length. vector.size() = " + vector.size() + ", condition.size() = " + condition.size() + ", weights.size() = " + weights.size());
        }
    }

    /**
     * Calculates the discrete Shannon/Guiasu weighted entropy, using histogram
     * probability estimators.
     * @param <T> Type of the array.
     * @param vector The array of values.
     * @param weights Array of weight values.
     * @return The weighted entropy H_w(vector).
     */
    public static <T> double weightedEntropy(ArrayList<T> vector, ArrayList<Double> weights) {
        if (vector.size() == weights.size()) {
            double vectorLength = vector.size();
            double entropy = 0.0;

            Map<T,WeightCountTuple> countDist = calculateWeightedCountDist(vector,weights);
            for (Entry<T,WeightCountTuple> e : countDist.entrySet()) {
                long count = e.getValue().count;
                double weight = e.getValue().weight;
                double prob = count / vectorLength;
                entropy -= weight * prob * Math.log(prob);
            }
            entropy /= LOG_BASE;

            double stateRatio = vectorLength / countDist.size();
            if (stateRatio < SAMPLES_RATIO) {
                logger.log(Level.INFO, "Weighted Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, stateRatio});
            }

            return entropy;
        } else {
            throw new IllegalArgumentException("Weighted Entropy requires two vectors the same length. vector.size() = " + vector.size() + ", weights.size() = " + weights.size());
        }
    }
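
    /*
     * A sketch of the single-variable helpers below (hypothetical data).
     * calculateWeightedCountDist accumulates a count and a running total
     * weight per state, then normaliseWeights turns the total into a
     * per-state average:
     *
     *   ArrayList<String> xs = new ArrayList<>(Arrays.asList("a", "a", "b"));
     *   ArrayList<Double> ws = new ArrayList<>(Arrays.asList(0.5, 1.5, 1.0));
     *   Map<String,WeightCountTuple> dist =
     *       WeightedInformationTheory.calculateWeightedCountDist(xs, ws);
     *   // dist.get("a").count == 2, dist.get("a").weight == 1.0 (the average of 0.5 and 1.5)
     */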

    /**
     * Generates the counts for a single vector.
     * @param <T> The type inside the vector.
     * @param vector An array of values.
     * @param weights The array of weight values.
     * @return A Map from each state of T to a {@link WeightCountTuple} containing
     * the count and the average weight for that state.
     */
    public static <T> Map<T,WeightCountTuple> calculateWeightedCountDist(ArrayList<T> vector, ArrayList<Double> weights) {
        Map<T,WeightCountTuple> dist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        for (int i = 0; i < vector.size(); i++) {
            T e = vector.get(i);
            double weight = weights.get(i);
            WeightCountTuple curVal = dist.computeIfAbsent(e,(k) -> new WeightCountTuple());
            curVal.count += 1;
            curVal.weight += weight;
        }

        normaliseWeights(dist);

        return dist;
    }

    /**
     * Normalizes the weights in the map, i.e., divides each accumulated weight
     * by its count, leaving the average weight per state.
     * @param map The map to normalize.
     * @param <T> The type of the variable that was counted.
     */
    public static <T> void normaliseWeights(Map<T,WeightCountTuple> map) {
        for (Entry<T,WeightCountTuple> e : map.entrySet()) {
            WeightCountTuple tuple = e.getValue();
            tuple.weight /= tuple.count;
        }
    }
}