/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.util.infotheory;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.util.infotheory.impl.CachedPair;
import org.tribuo.util.infotheory.impl.CachedTriple;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.tribuo.util.infotheory.impl.Row;
import org.tribuo.util.infotheory.impl.RowList;
import org.tribuo.util.infotheory.impl.TripleDistribution;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;

/**
 * A class of (discrete) information theoretic functions. Gives warnings if
 * there are insufficient samples to estimate the quantities accurately.
 * <p>
 * Defaults to log_2, so returns values in bits.
 * <p>
 * All functions expect that the element types have well defined equals and
 * hashcode, and that equals is consistent with hashcode. The behaviour is undefined
 * if this is not true.
047 */ 048public final class InformationTheory { 049 private static final Logger logger = Logger.getLogger(InformationTheory.class.getName()); 050 051 public static final double SAMPLES_RATIO = 5.0; 052 public static final int DEFAULT_MAP_SIZE = 20; 053 public static final double LOG_2 = Math.log(2); 054 public static final double LOG_E = Math.log(Math.E); 055 056 /** 057 * Sets the base of the logarithm used in the information theoretic calculations. 058 * For LOG_2 the unit is "bit", for LOG_E the unit is "nat". 059 */ 060 public static double LOG_BASE = LOG_2; 061 062 /** 063 * Private constructor, only has static methods. 064 */ 065 private InformationTheory() {} 066 067 /** 068 * Calculates the mutual information between the two sets of random variables. 069 * @param first The first set of random variables. 070 * @param second The second set of random variables. 071 * @param <T1> The first type. 072 * @param <T2> The second type. 073 * @return The mutual information I(first;second). 074 */ 075 public static <T1,T2> double mi(Set<List<T1>> first, Set<List<T2>> second) { 076 List<Row<T1>> firstList = new RowList<>(first); 077 List<Row<T2>> secondList = new RowList<>(second); 078 079 return mi(firstList,secondList); 080 } 081 082 /** 083 * Calculates the conditional mutual information between first and second conditioned on the set. 084 * @param first A sample from the first random variable. 085 * @param second A sample from the second random variable. 086 * @param condition A sample from the conditioning set of random variables. 087 * @param <T1> The first type. 088 * @param <T2> The second type. 089 * @param <T3> The third type. 090 * @return The conditional mutual information I(first;second|condition). 
091 */ 092 public static <T1,T2,T3> double cmi(List<T1> first, List<T2> second, Set<List<T3>> condition) { 093 if (condition.isEmpty()) { 094 //logger.log(Level.INFO,"Empty conditioning set"); 095 return mi(first,second); 096 } else { 097 List<Row<T3>> conditionList = new RowList<>(condition); 098 099 return conditionalMI(first,second,conditionList); 100 } 101 } 102 103 /** 104 * Calculates the GTest statistics for the input variables conditioned on the set. 105 * @param first A sample from the first random variable. 106 * @param second A sample from the second random variable. 107 * @param condition A sample from the conditioning set of random variables. 108 * @param <T1> The first type. 109 * @param <T2> The second type. 110 * @param <T3> The third type. 111 * @return The GTest statistics. 112 */ 113 public static <T1,T2,T3> GTestStatistics gTest(List<T1> first, List<T2> second, Set<List<T3>> condition) { 114 ScoreStateCountTuple tuple; 115 if (condition == null) { 116 //logger.log(Level.INFO,"Null conditioning set"); 117 tuple = innerMI(first,second); 118 } else if (condition.isEmpty()) { 119 //logger.log(Level.INFO,"Empty conditioning set"); 120 tuple = innerMI(first,second); 121 } else { 122 List<Row<T3>> conditionList = new RowList<>(condition); 123 124 tuple = innerConditionalMI(first,second,conditionList); 125 } 126 double gMetric = 2 * second.size() * tuple.score; 127 ChiSquaredDistribution dist = new ChiSquaredDistribution(tuple.stateCount); 128 double prob = dist.cumulativeProbability(gMetric); 129 GTestStatistics test = new GTestStatistics(gMetric,tuple.stateCount,prob); 130 return test; 131 } 132 133 /** 134 * Calculates the discrete Shannon joint mutual information, using 135 * histogram probability estimators. Arrays must be the same length. 136 * @param <T1> Type contained in the first array. 137 * @param <T2> Type contained in the second array. 138 * @param <T3> Type contained in the target array. 139 * @param first An array of values. 
140 * @param second Another array of values. 141 * @param target Target array of values. 142 * @return The mutual information I(first,second;joint) 143 */ 144 public static <T1,T2,T3> double jointMI(List<T1> first, List<T2> second, List<T3> target) { 145 if ((first.size() == second.size()) && (first.size() == target.size())) { 146 TripleDistribution<T1,T2,T3> tripleRV = TripleDistribution.constructFromLists(first,second,target); 147 return jointMI(tripleRV); 148 } else { 149 throw new IllegalArgumentException("Joint Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", target.size() = " + target.size()); 150 } 151 } 152 153 /** 154 * Calculates the discrete Shannon joint mutual information, using 155 * histogram probability estimators. Arrays must be the same length. 156 * @param <T1> Type contained in the first array. 157 * @param <T2> Type contained in the second array. 158 * @param <T3> Type contained in the target array. 
159 * @param rv The random variable to calculate the joint mi of 160 * @return The mutual information I(first,second;joint) 161 */ 162 public static <T1,T2,T3> double jointMI(TripleDistribution<T1,T2,T3> rv) { 163 double vecLength = rv.count; 164 Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount(); 165 Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount(); 166 Map<T3,MutableLong> cCount = rv.getCCount(); 167 168 double jmi = 0.0; 169 for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) { 170 double jointCurCount = e.getValue().doubleValue(); 171 double prob = jointCurCount / vecLength; 172 CachedPair<T1,T2> pair = e.getKey().getAB(); 173 double abCurCount = abCount.get(pair).doubleValue(); 174 double cCurCount = cCount.get(e.getKey().getC()).doubleValue(); 175 176 jmi += prob * Math.log((vecLength*jointCurCount)/(abCurCount*cCurCount)); 177 } 178 jmi /= LOG_BASE; 179 180 double stateRatio = vecLength / jointCount.size(); 181 if (stateRatio < SAMPLES_RATIO) { 182 logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", new Object[]{jmi, stateRatio, vecLength, jointCount.size()}); 183 } 184 185 return jmi; 186 } 187 188 /** 189 * Calculates the conditional mutual information. If flipped == true, then calculates I(T1;T3|T2), otherwise calculates I(T1;T2|T3). 190 * @param <T1> The type of the first argument. 191 * @param <T2> The type of the second argument. 192 * @param <T3> The type of the third argument. 193 * @param rv The random variable. 194 * @param flipped If true then the second element is the conditional variable, otherwise the third element is. 195 * @return A ScoreStateCountTuple containing the conditional mutual information and the number of states in the joint random variable. 
196 */ 197 private static <T1,T2,T3> ScoreStateCountTuple innerConditionalMI(TripleDistribution<T1,T2,T3> rv, boolean flipped) { 198 Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount(); 199 Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount(); 200 Map<CachedPair<T1,T3>,MutableLong> acCount = rv.getACCount(); 201 Map<CachedPair<T2,T3>,MutableLong> bcCount = rv.getBCCount(); 202 Map<T2,MutableLong> bCount = rv.getBCount(); 203 Map<T3,MutableLong> cCount = rv.getCCount(); 204 205 double vectorLength = rv.count; 206 double cmi = 0.0; 207 if (flipped) { 208 for (Entry<CachedTriple<T1,T2,T3>, MutableLong> e : jointCount.entrySet()) { 209 double jointCurCount = e.getValue().doubleValue(); 210 double prob = jointCurCount / vectorLength; 211 CachedPair<T1,T2> abPair = e.getKey().getAB(); 212 CachedPair<T2,T3> bcPair = e.getKey().getBC(); 213 double abCurCount = abCount.get(abPair).doubleValue(); 214 double bcCurCount = bcCount.get(bcPair).doubleValue(); 215 double bCurCount = bCount.get(e.getKey().getB()).doubleValue(); 216 217 cmi += prob * Math.log((bCurCount * jointCurCount) / (abCurCount * bcCurCount)); 218 } 219 } else { 220 for (Entry<CachedTriple<T1, T2, T3>, MutableLong> e : jointCount.entrySet()) { 221 double jointCurCount = e.getValue().doubleValue(); 222 double prob = jointCurCount / vectorLength; 223 CachedPair<T1, T3> acPair = e.getKey().getAC(); 224 CachedPair<T2, T3> bcPair = e.getKey().getBC(); 225 double acCurCount = acCount.get(acPair).doubleValue(); 226 double bcCurCount = bcCount.get(bcPair).doubleValue(); 227 double cCurCount = cCount.get(e.getKey().getC()).doubleValue(); 228 229 cmi += prob * Math.log((cCurCount * jointCurCount) / (acCurCount * bcCurCount)); 230 } 231 } 232 cmi /= LOG_BASE; 233 234 double stateRatio = vectorLength / jointCount.size(); 235 if (stateRatio < SAMPLES_RATIO) { 236 logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio}); 237 } 238 239 
return new ScoreStateCountTuple(cmi,jointCount.size()); 240 } 241 242 /** 243 * Calculates the conditional mutual information, I(T1;T2|T3). 244 * @param <T1> The type of the first argument. 245 * @param <T2> The type of the second argument. 246 * @param <T3> The type of the third argument. 247 * @param first The first random variable. 248 * @param second The second random variable. 249 * @param condition The conditioning random variable. 250 * @return A ScoreStateCountTuple containing the conditional mutual information and the number of states in the joint random variable. 251 */ 252 private static <T1,T2,T3> ScoreStateCountTuple innerConditionalMI(List<T1> first, List<T2> second, List<T3> condition) { 253 if ((first.size() == second.size()) && (first.size() == condition.size())) { 254 TripleDistribution<T1,T2,T3> tripleRV = TripleDistribution.constructFromLists(first,second,condition); 255 256 return innerConditionalMI(tripleRV,false); 257 } else { 258 throw new IllegalArgumentException("Conditional Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", condition.size() = " + condition.size()); 259 } 260 } 261 262 /** 263 * Calculates the discrete Shannon conditional mutual information, using 264 * histogram probability estimators. Arrays must be the same length. 265 * @param <T1> Type contained in the first array. 266 * @param <T2> Type contained in the second array. 267 * @param <T3> Type contained in the condition array. 268 * @param first An array of values. 269 * @param second Another array of values. 270 * @param condition Array to condition upon. 
271 * @return The conditional mutual information I(first;second|condition) 272 */ 273 public static <T1,T2,T3> double conditionalMI(List<T1> first, List<T2> second, List<T3> condition) { 274 return innerConditionalMI(first,second,condition).score; 275 } 276 277 /** 278 * Calculates the discrete Shannon conditional mutual information, using 279 * histogram probability estimators. Note this calculates I(T1;T2|T3). 280 * @param <T1> Type of the first variable. 281 * @param <T2> Type of the second variable. 282 * @param <T3> Type of the condition variable. 283 * @param rv The triple random variable of the three inputs. 284 * @return The conditional mutual information I(first;second|condition) 285 */ 286 public static <T1,T2,T3> double conditionalMI(TripleDistribution<T1,T2,T3> rv) { 287 return innerConditionalMI(rv,false).score; 288 } 289 290 /** 291 * Calculates the discrete Shannon conditional mutual information, using 292 * histogram probability estimators. Note this calculates I(T1;T3|T2). 293 * @param <T1> Type of the first variable. 294 * @param <T2> Type of the condition variable. 295 * @param <T3> Type of the second variable. 296 * @param rv The triple random variable of the three inputs. 297 * @return The conditional mutual information I(first;second|condition) 298 */ 299 public static <T1,T2,T3> double conditionalMIFlipped(TripleDistribution<T1,T2,T3> rv) { 300 return innerConditionalMI(rv,true).score; 301 } 302 303 /** 304 * Calculates the mutual information from a joint random variable. 305 * @param pairDist The joint distribution. 306 * @param <T1> The first type. 307 * @param <T2> The second type. 308 * @return A ScoreStateCountTuple containing the mutual information and the number of states in the joint variable. 
309 */ 310 private static <T1,T2> ScoreStateCountTuple innerMI(PairDistribution<T1,T2> pairDist) { 311 Map<CachedPair<T1,T2>,MutableLong> countDist = pairDist.jointCounts; 312 Map<T1,MutableLong> firstCountDist = pairDist.firstCount; 313 Map<T2,MutableLong> secondCountDist = pairDist.secondCount; 314 315 double vectorLength = pairDist.count; 316 double mi = 0.0; 317 boolean error = false; 318 for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) { 319 double jointCount = e.getValue().doubleValue(); 320 double prob = jointCount / vectorLength; 321 double firstProb = firstCountDist.get(e.getKey().getA()).doubleValue(); 322 double secondProb = secondCountDist.get(e.getKey().getB()).doubleValue(); 323 324 double top = vectorLength * jointCount; 325 double bottom = firstProb * secondProb; 326 double ratio = top/bottom; 327 double logRatio = Math.log(ratio); 328 329 if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(mi)) { 330 logger.log(Level.WARNING, "State = " + e.getKey().toString()); 331 logger.log(Level.WARNING, "mi = " + mi + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio); 332 error = true; 333 } 334 mi += prob * logRatio; 335 //mi += prob * Math.log((vectorLength*jointCount)/(firstProb*secondProb)); 336 } 337 mi /= LOG_BASE; 338 339 double stateRatio = vectorLength / countDist.size(); 340 if (stateRatio < SAMPLES_RATIO) { 341 logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio}); 342 } 343 344 if (error) { 345 logger.log(Level.SEVERE, "NanFound ", new IllegalStateException("NaN found")); 346 } 347 348 return new ScoreStateCountTuple(mi,countDist.size()); 349 } 350 351 /** 352 * Calculates the mutual information between the two lists. 353 * @param first The first list. 354 * @param second The second list. 355 * @param <T1> The first type. 356 * @param <T2> The second type. 
357 * @return A ScoreStateCountTuple containing the mutual information and the number of states in the joint variable. 358 */ 359 private static <T1,T2> ScoreStateCountTuple innerMI(List<T1> first, List<T2> second) { 360 if (first.size() == second.size()) { 361 PairDistribution<T1,T2> pairDist = PairDistribution.constructFromLists(first, second); 362 363 return innerMI(pairDist); 364 } else { 365 throw new IllegalArgumentException("Mutual Information requires two vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size()); 366 } 367 } 368 369 /** 370 * Calculates the discrete Shannon mutual information, using histogram 371 * probability estimators. Arrays must be the same length. 372 * @param <T1> Type of the first array 373 * @param <T2> Type of the second array 374 * @param first An array of values 375 * @param second Another array of values 376 * @return The mutual information I(first;second) 377 */ 378 public static <T1,T2> double mi(List<T1> first, List<T2> second) { 379 return innerMI(first,second).score; 380 } 381 382 /** 383 * Calculates the discrete Shannon mutual information, using histogram 384 * probability estimators. 385 * @param <T1> Type of the first variable 386 * @param <T2> Type of the second variable 387 * @param pairDist PairDistribution for the two variables. 388 * @return The mutual information I(first;second) 389 */ 390 public static <T1,T2> double mi(PairDistribution<T1,T2> pairDist) { 391 return innerMI(pairDist).score; 392 } 393 394 /** 395 * Calculates the Shannon joint entropy of two arrays, using histogram 396 * probability estimators. Arrays must be same length. 397 * @param <T1> Type of the first array. 398 * @param <T2> Type of the second array. 399 * @param first An array of values. 400 * @param second Another array of values. 
401 * @return The entropy H(first,second) 402 */ 403 public static <T1,T2> double jointEntropy(List<T1> first, List<T2> second) { 404 if (first.size() == second.size()) { 405 double vectorLength = first.size(); 406 double jointEntropy = 0.0; 407 408 PairDistribution<T1,T2> countPair = PairDistribution.constructFromLists(first,second); 409 Map<CachedPair<T1,T2>,MutableLong> countDist = countPair.jointCounts; 410 411 for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) { 412 double prob = e.getValue().doubleValue() / vectorLength; 413 414 jointEntropy -= prob * Math.log(prob); 415 } 416 jointEntropy /= LOG_BASE; 417 418 double stateRatio = vectorLength / countDist.size(); 419 if (stateRatio < SAMPLES_RATIO) { 420 logger.log(Level.INFO, "Joint Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{jointEntropy, stateRatio}); 421 } 422 423 return jointEntropy; 424 } else { 425 throw new IllegalArgumentException("Joint Entropy requires two vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size()); 426 } 427 } 428 429 /** 430 * Calculates the discrete Shannon conditional entropy of two arrays, using 431 * histogram probability estimators. Arrays must be the same length. 432 * @param <T1> Type of the first array. 433 * @param <T2> Type of the second array. 434 * @param vector The main array of values. 435 * @param condition The array to condition on. 436 * @return The conditional entropy H(vector|condition). 
437 */ 438 public static <T1,T2> double conditionalEntropy(List<T1> vector, List<T2> condition) { 439 if (vector.size() == condition.size()) { 440 double vectorLength = vector.size(); 441 double condEntropy = 0.0; 442 443 PairDistribution<T1,T2> countPair = PairDistribution.constructFromLists(vector,condition); 444 Map<CachedPair<T1,T2>,MutableLong> countDist = countPair.jointCounts; 445 Map<T2,MutableLong> conditionCountDist = countPair.secondCount; 446 447 for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) { 448 double prob = e.getValue().doubleValue() / vectorLength; 449 double condProb = conditionCountDist.get(e.getKey().getB()).doubleValue() / vectorLength; 450 451 condEntropy -= prob * Math.log(prob/condProb); 452 } 453 condEntropy /= LOG_BASE; 454 455 double stateRatio = vectorLength / countDist.size(); 456 if (stateRatio < SAMPLES_RATIO) { 457 logger.log(Level.INFO, "Conditional Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{condEntropy, stateRatio}); 458 } 459 460 return condEntropy; 461 } else { 462 throw new IllegalArgumentException("Conditional Entropy requires two vectors the same length. vector.size() = " + vector.size() + ", condition.size() = " + condition.size()); 463 } 464 } 465 466 /** 467 * Calculates the discrete Shannon entropy, using histogram probability 468 * estimators. 469 * @param <T> Type of the array. 470 * @param vector The array of values. 471 * @return The entropy H(vector). 
472 */ 473 public static <T> double entropy(List<T> vector) { 474 double vectorLength = vector.size(); 475 double entropy = 0.0; 476 477 Map<T,Long> countDist = calculateCountDist(vector); 478 for (Entry<T,Long> e : countDist.entrySet()) { 479 double prob = e.getValue() / vectorLength; 480 entropy -= prob * Math.log(prob); 481 } 482 entropy /= LOG_BASE; 483 484 double stateRatio = vectorLength / countDist.size(); 485 if (stateRatio < SAMPLES_RATIO) { 486 logger.log(Level.INFO, "Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, stateRatio}); 487 } 488 489 return entropy; 490 } 491 492 /** 493 * Generate the counts for a single vector. 494 * @param <T> The type inside the vector. 495 * @param vector An array of values. 496 * @return A HashMap from states of T to counts. 497 */ 498 public static <T> Map<T,Long> calculateCountDist(List<T> vector) { 499 HashMap<T,Long> countDist = new HashMap<>(DEFAULT_MAP_SIZE); 500 for (T e : vector) { 501 Long curCount = countDist.getOrDefault(e,0L); 502 curCount += 1; 503 countDist.put(e, curCount); 504 } 505 506 return countDist; 507 } 508 509 /** 510 * Calculates the discrete Shannon entropy of a stream, assuming each element of the stream is 511 * an element of the same probability distribution. 512 * @param vector The probability distribution. 513 * @return The entropy. 514 */ 515 public static double calculateEntropy(Stream<Double> vector) { 516 return vector.map((p) -> (- p * Math.log(p) / LOG_BASE)).reduce(0.0, Double::sum); 517 } 518 519 /** 520 * Calculates the discrete Shannon entropy of a stream, assuming each element of the stream is 521 * an element of the same probability distribution. 522 * @param vector The probability distribution. 523 * @return The entropy. 
524 */ 525 public static double calculateEntropy(DoubleStream vector) { 526 return vector.map((p) -> (- p * Math.log(p) / LOG_BASE)).sum(); 527 } 528 529 /** 530 * A tuple of the information theoretic value, along with the number of 531 * states in the random variable. 532 */ 533 private static class ScoreStateCountTuple { 534 public final double score; 535 public final int stateCount; 536 537 public ScoreStateCountTuple(double score, int stateCount) { 538 this.score = score; 539 this.stateCount = stateCount; 540 } 541 542 @Override 543 public String toString() { 544 return "ScoreStateCount(score=" + score + ",stateCount=" + stateCount + ")"; 545 } 546 } 547 548 /** 549 * An immutable named tuple containing the statistics from a G test. 550 */ 551 public static final class GTestStatistics { 552 public final double gStatistic; 553 public final int numStates; 554 public final double probability; 555 556 public GTestStatistics(double gStatistic, int numStates, double probability) { 557 this.gStatistic = gStatistic; 558 this.numStates = numStates; 559 this.probability = probability; 560 } 561 562 @Override 563 public String toString() { 564 return "GTest(statistic="+gStatistic+",probability="+probability+",numStates="+numStates+")"; 565 } 566 } 567} 568