001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import org.tribuo.util.infotheory.impl.CachedPair;
021import org.tribuo.util.infotheory.impl.CachedTriple;
022import org.tribuo.util.infotheory.impl.PairDistribution;
023import org.tribuo.util.infotheory.impl.Row;
024import org.tribuo.util.infotheory.impl.RowList;
025import org.tribuo.util.infotheory.impl.TripleDistribution;
026import org.apache.commons.math3.distribution.ChiSquaredDistribution;
027
028import java.util.HashMap;
029import java.util.List;
030import java.util.Map;
031import java.util.Map.Entry;
032import java.util.Set;
033import java.util.logging.Level;
034import java.util.logging.Logger;
035import java.util.stream.DoubleStream;
036import java.util.stream.Stream;
037
038/**
039 * A class of (discrete) information theoretic functions. Gives warnings if
040 * there are insufficient samples to estimate the quantities accurately.
041 * <p>
042 * Defaults to log_2, so returns values in bits.
043 * <p>
044 * All functions expect that the element types have well defined equals and
045 * hashcode, and that equals is consistent with hashcode. The behaviour is undefined
046 * if this is not true.
047 */
048public final class InformationTheory {
    private static final Logger logger = Logger.getLogger(InformationTheory.class.getName());

    /** The minimum samples-per-state ratio; below this a warning about estimation quality is logged. */
    public static final double SAMPLES_RATIO = 5.0;
    /** The default initial capacity for the count maps. */
    public static final int DEFAULT_MAP_SIZE = 20;
    /** The natural logarithm of 2; using this as the base gives results in bits. */
    public static final double LOG_2 = Math.log(2);
    /** The natural logarithm of e (i.e., 1.0); using this as the base gives results in nats. */
    public static final double LOG_E = Math.log(Math.E);

    /**
     * Sets the base of the logarithm used in the information theoretic calculations.
     * For LOG_2 the unit is "bit", for LOG_E the unit is "nat".
     */
    public static double LOG_BASE = LOG_2;

    /**
     * Private constructor, only has static methods.
     */
    private InformationTheory() {}
066
067    /**
068     * Calculates the mutual information between the two sets of random variables.
069     * @param first The first set of random variables.
070     * @param second The second set of random variables.
071     * @param <T1> The first type.
072     * @param <T2> The second type.
073     * @return The mutual information I(first;second).
074     */
075    public static <T1,T2> double mi(Set<List<T1>> first, Set<List<T2>> second) {
076        List<Row<T1>> firstList = new RowList<>(first);
077        List<Row<T2>> secondList = new RowList<>(second);
078
079        return mi(firstList,secondList);
080    }
081
082    /**
083     * Calculates the conditional mutual information between first and second conditioned on the set.
084     * @param first A sample from the first random variable.
085     * @param second A sample from the second random variable.
086     * @param condition A sample from the conditioning set of random variables.
087     * @param <T1> The first type.
088     * @param <T2> The second type.
089     * @param <T3> The third type.
090     * @return The conditional mutual information I(first;second|condition).
091     */
092    public static <T1,T2,T3> double cmi(List<T1> first, List<T2> second, Set<List<T3>> condition) {
093        if (condition.isEmpty()) {
094            //logger.log(Level.INFO,"Empty conditioning set");
095            return mi(first,second);
096        } else {
097            List<Row<T3>> conditionList = new RowList<>(condition);
098        
099            return conditionalMI(first,second,conditionList);
100        }
101    }
102
103    /**
104     * Calculates the GTest statistics for the input variables conditioned on the set.
105     * @param first A sample from the first random variable.
106     * @param second A sample from the second random variable.
107     * @param condition A sample from the conditioning set of random variables.
108     * @param <T1> The first type.
109     * @param <T2> The second type.
110     * @param <T3> The third type.
111     * @return The GTest statistics.
112     */
113    public static <T1,T2,T3> GTestStatistics gTest(List<T1> first, List<T2> second, Set<List<T3>> condition) {
114        ScoreStateCountTuple tuple;
115        if (condition == null) {
116            //logger.log(Level.INFO,"Null conditioning set");
117            tuple = innerMI(first,second);
118        } else if (condition.isEmpty()) {
119            //logger.log(Level.INFO,"Empty conditioning set");
120            tuple = innerMI(first,second);
121        } else {
122            List<Row<T3>> conditionList = new RowList<>(condition);
123        
124            tuple = innerConditionalMI(first,second,conditionList);
125        }
126        double gMetric = 2 * second.size() * tuple.score;
127        ChiSquaredDistribution dist = new ChiSquaredDistribution(tuple.stateCount);
128        double prob = dist.cumulativeProbability(gMetric);
129        GTestStatistics test = new GTestStatistics(gMetric,tuple.stateCount,prob);
130        return test;
131    }
132
133    /**
134     * Calculates the discrete Shannon joint mutual information, using
135     * histogram probability estimators. Arrays must be the same length.
136     * @param <T1> Type contained in the first array.
137     * @param <T2> Type contained in the second array.
138     * @param <T3> Type contained in the target array.
139     * @param first An array of values.
140     * @param second Another array of values.
141     * @param target Target array of values.
142     * @return The mutual information I(first,second;joint)
143     */
144    public static <T1,T2,T3> double jointMI(List<T1> first, List<T2> second, List<T3> target) {
145        if ((first.size() == second.size()) && (first.size() == target.size())) {
146            TripleDistribution<T1,T2,T3> tripleRV = TripleDistribution.constructFromLists(first,second,target);
147            return jointMI(tripleRV);
148        } else {
149            throw new IllegalArgumentException("Joint Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", target.size() = " + target.size());
150        }
151    }
152
153    /**
154     * Calculates the discrete Shannon joint mutual information, using
155     * histogram probability estimators. Arrays must be the same length.
156     * @param <T1> Type contained in the first array.
157     * @param <T2> Type contained in the second array.
158     * @param <T3> Type contained in the target array.
159     * @param rv The random variable to calculate the joint mi of
160     * @return The mutual information I(first,second;joint)
161     */
162    public static <T1,T2,T3> double jointMI(TripleDistribution<T1,T2,T3> rv) {
163        double vecLength = rv.count;
164        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
165        Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount();
166        Map<T3,MutableLong> cCount = rv.getCCount();
167
168        double jmi = 0.0;
169        for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) {
170            double jointCurCount = e.getValue().doubleValue();
171            double prob = jointCurCount / vecLength;
172            CachedPair<T1,T2> pair = e.getKey().getAB();
173            double abCurCount = abCount.get(pair).doubleValue();
174            double cCurCount = cCount.get(e.getKey().getC()).doubleValue();
175
176            jmi += prob * Math.log((vecLength*jointCurCount)/(abCurCount*cCurCount));
177        }
178        jmi /= LOG_BASE;
179        
180        double stateRatio = vecLength / jointCount.size();
181        if (stateRatio < SAMPLES_RATIO) {
182            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", new Object[]{jmi, stateRatio, vecLength, jointCount.size()});
183        }
184
185        return jmi;
186    }
187
    /**
     * Calculates the conditional mutual information. If flipped == true, then calculates I(T1;T3|T2), otherwise calculates I(T1;T2|T3).
     * <p>
     * Uses the identity I(A;B|C) = sum_{a,b,c} p(a,b,c) * log( p(c) p(a,b,c) / (p(a,c) p(b,c)) ),
     * with all probabilities replaced by raw counts (the normalising factors cancel).
     * @param <T1> The type of the first argument.
     * @param <T2> The type of the second argument.
     * @param <T3> The type of the third argument.
     * @param rv The random variable.
     * @param flipped If true then the second element is the conditional variable, otherwise the third element is.
     * @return A ScoreStateCountTuple containing the conditional mutual information and the number of states in the joint random variable.
     */
    private static <T1,T2,T3> ScoreStateCountTuple innerConditionalMI(TripleDistribution<T1,T2,T3> rv, boolean flipped) {
        // Joint and marginal counts over the (A,B,C) triple.
        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
        Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount();
        Map<CachedPair<T1,T3>,MutableLong> acCount = rv.getACCount();
        Map<CachedPair<T2,T3>,MutableLong> bcCount = rv.getBCCount();
        Map<T2,MutableLong> bCount = rv.getBCount();
        Map<T3,MutableLong> cCount = rv.getCCount();

        double vectorLength = rv.count;
        double cmi = 0.0;
        if (flipped) {
            // I(A;C|B): condition on B, so the B marginal and the AB/BC pairs appear.
            for (Entry<CachedTriple<T1,T2,T3>, MutableLong> e : jointCount.entrySet()) {
                double jointCurCount = e.getValue().doubleValue();
                double prob = jointCurCount / vectorLength;
                CachedPair<T1,T2> abPair = e.getKey().getAB();
                CachedPair<T2,T3> bcPair = e.getKey().getBC();
                double abCurCount = abCount.get(abPair).doubleValue();
                double bcCurCount = bcCount.get(bcPair).doubleValue();
                double bCurCount = bCount.get(e.getKey().getB()).doubleValue();

                // count form of p(b) p(a,b,c) / (p(a,b) p(b,c)); the 1/N factors cancel.
                cmi += prob * Math.log((bCurCount * jointCurCount) / (abCurCount * bcCurCount));
            }
        } else {
            // I(A;B|C): condition on C, so the C marginal and the AC/BC pairs appear.
            for (Entry<CachedTriple<T1, T2, T3>, MutableLong> e : jointCount.entrySet()) {
                double jointCurCount = e.getValue().doubleValue();
                double prob = jointCurCount / vectorLength;
                CachedPair<T1, T3> acPair = e.getKey().getAC();
                CachedPair<T2, T3> bcPair = e.getKey().getBC();
                double acCurCount = acCount.get(acPair).doubleValue();
                double bcCurCount = bcCount.get(bcPair).doubleValue();
                double cCurCount = cCount.get(e.getKey().getC()).doubleValue();

                // count form of p(c) p(a,b,c) / (p(a,c) p(b,c)); the 1/N factors cancel.
                cmi += prob * Math.log((cCurCount * jointCurCount) / (acCurCount * bcCurCount));
            }
        }
        // Accumulated in natural log; rescale into the configured base.
        cmi /= LOG_BASE;

        // Warn when there are too few samples per state for a reliable estimate.
        double stateRatio = vectorLength / jointCount.size();
        if (stateRatio < SAMPLES_RATIO) {
            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio});
        }
        
        return new ScoreStateCountTuple(cmi,jointCount.size());
    }
241
242    /**
243     * Calculates the conditional mutual information, I(T1;T2|T3).
244     * @param <T1> The type of the first argument.
245     * @param <T2> The type of the second argument.
246     * @param <T3> The type of the third argument.
247     * @param first The first random variable.
248     * @param second The second random variable.
249     * @param condition The conditioning random variable.
250     * @return A ScoreStateCountTuple containing the conditional mutual information and the number of states in the joint random variable.
251     */
252    private static <T1,T2,T3> ScoreStateCountTuple innerConditionalMI(List<T1> first, List<T2> second, List<T3> condition) {
253        if ((first.size() == second.size()) && (first.size() == condition.size())) {
254            TripleDistribution<T1,T2,T3> tripleRV = TripleDistribution.constructFromLists(first,second,condition);
255
256            return innerConditionalMI(tripleRV,false);
257        } else {
258            throw new IllegalArgumentException("Conditional Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", condition.size() = " + condition.size());
259        }
260    }
261    
    /**
     * Calculates the discrete Shannon conditional mutual information, using
     * histogram probability estimators. Arrays must be the same length.
     * @param <T1> Type contained in the first array.
     * @param <T2> Type contained in the second array.
     * @param <T3> Type contained in the condition array.
     * @param first An array of values.
     * @param second Another array of values.
     * @param condition Array to condition upon.
     * @return The conditional mutual information I(first;second|condition)
     * @throws IllegalArgumentException if the three lists have different lengths.
     */
    public static <T1,T2,T3> double conditionalMI(List<T1> first, List<T2> second, List<T3> condition) {
        // Delegates to the inner implementation, discarding the state count.
        return innerConditionalMI(first,second,condition).score;
    }
276
    /**
     * Calculates the discrete Shannon conditional mutual information, using
     * histogram probability estimators. Note this calculates I(T1;T2|T3).
     * The result is in units of LOG_BASE.
     * @param <T1> Type of the first variable.
     * @param <T2> Type of the second variable.
     * @param <T3> Type of the condition variable.
     * @param rv The triple random variable of the three inputs.
     * @return The conditional mutual information I(first;second|condition)
     */
    public static <T1,T2,T3> double conditionalMI(TripleDistribution<T1,T2,T3> rv) {
        // flipped=false conditions on the third element of the triple.
        return innerConditionalMI(rv,false).score;
    }
289
    /**
     * Calculates the discrete Shannon conditional mutual information, using
     * histogram probability estimators. Note this calculates I(T1;T3|T2).
     * The result is in units of LOG_BASE.
     * @param <T1> Type of the first variable.
     * @param <T2> Type of the condition variable.
     * @param <T3> Type of the second variable.
     * @param rv The triple random variable of the three inputs.
     * @return The conditional mutual information I(first;second|condition)
     */
    public static <T1,T2,T3> double conditionalMIFlipped(TripleDistribution<T1,T2,T3> rv) {
        // flipped=true conditions on the second element of the triple.
        return innerConditionalMI(rv,true).score;
    }
302
303    /**
304     * Calculates the mutual information from a joint random variable.
305     * @param pairDist The joint distribution.
306     * @param <T1> The first type.
307     * @param <T2> The second type.
308     * @return A ScoreStateCountTuple containing the mutual information and the number of states in the joint variable.
309     */
310    private static <T1,T2> ScoreStateCountTuple innerMI(PairDistribution<T1,T2> pairDist) {
311        Map<CachedPair<T1,T2>,MutableLong> countDist = pairDist.jointCounts;
312        Map<T1,MutableLong> firstCountDist = pairDist.firstCount;
313        Map<T2,MutableLong> secondCountDist = pairDist.secondCount;
314
315        double vectorLength = pairDist.count;
316        double mi = 0.0;
317        boolean error = false;
318        for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) {
319            double jointCount = e.getValue().doubleValue();
320            double prob = jointCount / vectorLength;
321            double firstProb = firstCountDist.get(e.getKey().getA()).doubleValue();
322            double secondProb = secondCountDist.get(e.getKey().getB()).doubleValue();
323
324            double top = vectorLength * jointCount;
325            double bottom = firstProb * secondProb;
326            double ratio = top/bottom;
327            double logRatio = Math.log(ratio);
328
329            if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(mi)) {
330                logger.log(Level.WARNING, "State = " + e.getKey().toString());
331                logger.log(Level.WARNING, "mi = " + mi + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio);
332                error = true;
333            }
334            mi += prob * logRatio;
335            //mi += prob * Math.log((vectorLength*jointCount)/(firstProb*secondProb));
336        }
337        mi /= LOG_BASE;
338
339        double stateRatio = vectorLength / countDist.size();
340        if (stateRatio < SAMPLES_RATIO) {
341            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio});
342        }
343        
344        if (error) {
345            logger.log(Level.SEVERE, "NanFound ", new IllegalStateException("NaN found"));
346        }
347        
348        return new ScoreStateCountTuple(mi,countDist.size());
349    }
350
351    /**
352     * Calculates the mutual information between the two lists.
353     * @param first The first list.
354     * @param second The second list.
355     * @param <T1> The first type.
356     * @param <T2> The second type.
357     * @return A ScoreStateCountTuple containing the mutual information and the number of states in the joint variable.
358     */
359    private static <T1,T2> ScoreStateCountTuple innerMI(List<T1> first, List<T2> second) {
360        if (first.size() == second.size()) {
361            PairDistribution<T1,T2> pairDist = PairDistribution.constructFromLists(first, second);
362            
363            return innerMI(pairDist);
364        } else {
365            throw new IllegalArgumentException("Mutual Information requires two vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
366        }
367    }
368    
    /**
     * Calculates the discrete Shannon mutual information, using histogram 
     * probability estimators. Arrays must be the same length.
     * @param <T1> Type of the first array
     * @param <T2> Type of the second array
     * @param first An array of values
     * @param second Another array of values
     * @return The mutual information I(first;second)
     * @throws IllegalArgumentException if the two lists have different lengths.
     */
    public static <T1,T2> double mi(List<T1> first, List<T2> second) {
        // Delegates to the inner implementation, discarding the state count.
        return innerMI(first,second).score;
    }
381
    /**
     * Calculates the discrete Shannon mutual information, using histogram 
     * probability estimators. The result is in units of LOG_BASE.
     * @param <T1> Type of the first variable
     * @param <T2> Type of the second variable
     * @param pairDist PairDistribution for the two variables.
     * @return The mutual information I(first;second)
     */
    public static <T1,T2> double mi(PairDistribution<T1,T2> pairDist) {
        // Delegates to the inner implementation, discarding the state count.
        return innerMI(pairDist).score;
    }
393
394    /**
395     * Calculates the Shannon joint entropy of two arrays, using histogram 
396     * probability estimators. Arrays must be same length.
397     * @param <T1> Type of the first array.
398     * @param <T2> Type of the second array.
399     * @param first An array of values.
400     * @param second Another array of values.
401     * @return The entropy H(first,second)
402     */
403    public static <T1,T2> double jointEntropy(List<T1> first, List<T2> second) {
404        if (first.size() == second.size()) {
405            double vectorLength = first.size();
406            double jointEntropy = 0.0;
407            
408            PairDistribution<T1,T2> countPair = PairDistribution.constructFromLists(first,second); 
409            Map<CachedPair<T1,T2>,MutableLong> countDist = countPair.jointCounts;
410
411            for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) {
412                double prob = e.getValue().doubleValue() / vectorLength;
413
414                jointEntropy -= prob * Math.log(prob);
415            }
416            jointEntropy /= LOG_BASE;
417
418            double stateRatio = vectorLength / countDist.size();
419            if (stateRatio < SAMPLES_RATIO) {
420                logger.log(Level.INFO, "Joint Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{jointEntropy, stateRatio});
421            }
422            
423            return jointEntropy;
424        } else {
425            throw new IllegalArgumentException("Joint Entropy requires two vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
426        }
427    }
428    
429    /**
430     * Calculates the discrete Shannon conditional entropy of two arrays, using
431     * histogram probability estimators. Arrays must be the same length.
432     * @param <T1> Type of the first array.
433     * @param <T2> Type of the second array.
434     * @param vector The main array of values.
435     * @param condition The array to condition on.
436     * @return The conditional entropy H(vector|condition).
437     */
438    public static <T1,T2> double conditionalEntropy(List<T1> vector, List<T2> condition) {
439        if (vector.size() == condition.size()) {
440            double vectorLength = vector.size();
441            double condEntropy = 0.0;
442            
443            PairDistribution<T1,T2> countPair = PairDistribution.constructFromLists(vector,condition); 
444            Map<CachedPair<T1,T2>,MutableLong> countDist = countPair.jointCounts;
445            Map<T2,MutableLong> conditionCountDist = countPair.secondCount;
446
447            for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) {
448                double prob = e.getValue().doubleValue() / vectorLength;
449                double condProb = conditionCountDist.get(e.getKey().getB()).doubleValue() / vectorLength;
450
451                condEntropy -= prob * Math.log(prob/condProb);
452            }
453            condEntropy /= LOG_BASE;
454
455            double stateRatio = vectorLength / countDist.size();
456            if (stateRatio < SAMPLES_RATIO) {
457                logger.log(Level.INFO, "Conditional Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{condEntropy, stateRatio});
458            }
459            
460            return condEntropy;
461        } else {
462            throw new IllegalArgumentException("Conditional Entropy requires two vectors the same length. vector.size() = " + vector.size() + ", condition.size() = " + condition.size());
463        }
464    }
465
466    /**
467     * Calculates the discrete Shannon entropy, using histogram probability 
468     * estimators.
469     * @param <T> Type of the array.
470     * @param vector The array of values.
471     * @return The entropy H(vector).
472     */
473    public static <T> double entropy(List<T> vector) {
474        double vectorLength = vector.size();
475        double entropy = 0.0;
476
477        Map<T,Long> countDist = calculateCountDist(vector);
478        for (Entry<T,Long> e : countDist.entrySet()) {
479            double prob = e.getValue() / vectorLength;
480            entropy -= prob * Math.log(prob);
481        }
482        entropy /= LOG_BASE;
483
484        double stateRatio = vectorLength / countDist.size();
485        if (stateRatio < SAMPLES_RATIO) {
486            logger.log(Level.INFO, "Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, stateRatio});
487        }
488        
489        return entropy;
490    }
491
492    /**
493     * Generate the counts for a single vector.
494     * @param <T> The type inside the vector.
495     * @param vector An array of values.
496     * @return A HashMap from states of T to counts.
497     */
498    public static <T> Map<T,Long> calculateCountDist(List<T> vector) {
499        HashMap<T,Long> countDist = new HashMap<>(DEFAULT_MAP_SIZE);
500        for (T e : vector) {
501            Long curCount = countDist.getOrDefault(e,0L);
502            curCount += 1;
503            countDist.put(e, curCount);
504        }
505
506        return countDist;
507    }
508
509    /**
510     * Calculates the discrete Shannon entropy of a stream, assuming each element of the stream is
511     * an element of the same probability distribution.
512     * @param vector The probability distribution.
513     * @return The entropy.
514     */
515    public static double calculateEntropy(Stream<Double> vector) {
516        return vector.map((p) -> (- p * Math.log(p) / LOG_BASE)).reduce(0.0, Double::sum);
517    }
518
    /**
     * Calculates the discrete Shannon entropy of a stream, assuming each element of the stream is
     * an element of the same probability distribution.
     * <p>
     * Note: a zero probability element yields NaN, as 0 * log(0) is not special-cased.
     * @param vector The probability distribution.
     * @return The entropy.
     */
    public static double calculateEntropy(DoubleStream vector) {
        // Sums -p*log(p) per element, rescaled into units of LOG_BASE.
        return vector.map((p) -> (- p * Math.log(p) / LOG_BASE)).sum();
    }
528
    /**
     * A tuple of the information theoretic value, along with the number of
     * states in the random variable.
     */
    private static class ScoreStateCountTuple {
        /** The information theoretic score, in units of LOG_BASE. */
        public final double score;
        /** The number of observed states in the joint random variable. */
        public final int stateCount;

        /**
         * Constructs a tuple from the score and the state count.
         * @param score The information theoretic score.
         * @param stateCount The number of observed states.
         */
        public ScoreStateCountTuple(double score, int stateCount) {
            this.score = score;
            this.stateCount = stateCount;
        }

        @Override
        public String toString() {
            return "ScoreStateCount(score=" + score + ",stateCount=" + stateCount + ")";
        }
    }
547
    /**
     * An immutable named tuple containing the statistics from a G test.
     */
    public static final class GTestStatistics {
        /** The G test statistic. */
        public final double gStatistic;
        /** The number of states used as the chi-squared degrees of freedom. */
        public final int numStates;
        /** The cumulative probability of the G statistic under the chi-squared distribution. */
        public final double probability;

        /**
         * Constructs a GTestStatistics tuple.
         * @param gStatistic The G test statistic.
         * @param numStates The number of states.
         * @param probability The cumulative probability of the statistic.
         */
        public GTestStatistics(double gStatistic, int numStates, double probability) {
            this.gStatistic = gStatistic;
            this.numStates = numStates;
            this.probability = probability;
        }

        @Override
        public String toString() {
            return "GTest(statistic="+gStatistic+",probability="+probability+",numStates="+numStates+")";
        }
    }
567}
568