001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import org.tribuo.util.infotheory.impl.CachedPair;
021import org.tribuo.util.infotheory.impl.CachedTriple;
022import org.tribuo.util.infotheory.impl.PairDistribution;
023import org.tribuo.util.infotheory.impl.TripleDistribution;
024import org.tribuo.util.infotheory.impl.WeightCountTuple;
025import org.tribuo.util.infotheory.impl.WeightedPairDistribution;
026import org.tribuo.util.infotheory.impl.WeightedTripleDistribution;
027
028import java.util.ArrayList;
029import java.util.LinkedHashMap;
030import java.util.List;
031import java.util.Map;
032import java.util.Map.Entry;
033import java.util.logging.Level;
034import java.util.logging.Logger;
035
036/**
037 * A class of (discrete) weighted information theoretic functions. Gives warnings if
038 * there are insufficient samples to estimate the quantities accurately.
039 * <p>
040 * Defaults to log_2, so returns values in bits.
041 * <p>
042 * All functions expect that the element types have well defined equals and
043 * hashcode, and that equals is consistent with hashcode. The behaviour is undefined
044 * if this is not true.
045 */
046public final class WeightedInformationTheory {
047    private static final Logger logger = Logger.getLogger(WeightedInformationTheory.class.getName());
048
049    public static final double SAMPLES_RATIO = 5.0;
050    public static final int DEFAULT_MAP_SIZE = 20;
051    public static final double LOG_2 = Math.log(2);
052    public static final double LOG_E = Math.log(Math.E);
053
054    /**
055     * Sets the base of the logarithm used in the information theoretic calculations.
056     * For LOG_2 the unit is "bit", for LOG_E the unit is "nat".
057     */
058    public static double LOG_BASE = LOG_2;
059
060    /**
061     * Chooses which variable is the one with associated weights.
062     */
063    public enum VariableSelector {
064        FIRST, SECOND, THIRD
065    }
066
067    /**
068     * Private constructor, only has static methods.
069     */
070    private WeightedInformationTheory() {}
071
072    /**
073     * Calculates the discrete weighted joint mutual information, using
074     * histogram probability estimators. Arrays must be the same length.
075     * @param <T1> Type contained in the first array.
076     * @param <T2> Type contained in the second array.
077     * @param <T3> Type contained in the target array.
078     * @param first An array of values.
079     * @param second Another array of values.
080     * @param target Target array of values.
081     * @param weights Array of weight values.
082     * @return The mutual information I(first,second;joint)
083     */
084    public static <T1,T2,T3> double jointMI(List<T1> first, List<T2> second, List<T3> target, List<Double> weights) {
085        WeightedTripleDistribution<T1, T2, T3> tripleRV = WeightedTripleDistribution.constructFromLists(first, second, target, weights);
086
087        return jointMI(tripleRV);
088    }
089
090    public static <T1,T2,T3> double jointMI(WeightedTripleDistribution<T1,T2,T3> tripleRV) {
091        Map<CachedTriple<T1,T2,T3>, WeightCountTuple> jointCount = tripleRV.getJointCount();
092        Map<CachedPair<T1,T2>,WeightCountTuple> abCount = tripleRV.getABCount();
093        Map<T3,WeightCountTuple> cCount = tripleRV.getCCount();
094
095        double vectorLength = tripleRV.count;
096        double jmi = 0.0;
097        for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) {
098            double jointCurCount = e.getValue().count;
099            double jointCurWeight = e.getValue().weight;
100            double prob = jointCurCount / vectorLength;
101            CachedPair<T1,T2> pair = e.getKey().getAB();
102            double abCurCount = abCount.get(pair).count;
103            double cCurCount = cCount.get(e.getKey().getC()).count;
104
105            jmi += jointCurWeight * prob * Math.log((vectorLength*jointCurCount)/(abCurCount*cCurCount));
106        }
107        jmi /= LOG_BASE;
108
109        double stateRatio = vectorLength / jointCount.size();
110        if (stateRatio < SAMPLES_RATIO) {
111            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}", new Object[]{jmi, stateRatio});
112        }
113        
114        return jmi;
115    }
116
117    public static <T1,T2,T3> double jointMI(TripleDistribution<T1,T2,T3> rv, Map<?,Double> weights, VariableSelector vs){
118        Double boxedWeight;
119        double vecLength = rv.count;
120        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
121        Map<CachedPair<T1,T2>,MutableLong> abCount = rv.getABCount();
122        Map<T3,MutableLong> cCount = rv.getCCount();
123
124        double jmi = 0.0;
125        for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) {
126            double jointCurCount = e.getValue().doubleValue();
127            double prob = jointCurCount / vecLength;
128            CachedPair<T1,T2> pair = new CachedPair<>(e.getKey().getA(),e.getKey().getB());
129            double abCurCount = abCount.get(pair).doubleValue();
130            double cCurCount = cCount.get(e.getKey().getC()).doubleValue();
131
132            double weight = 1.0;
133            switch (vs) {
134                case FIRST:
135                    boxedWeight = weights.get(e.getKey().getA());
136                    weight = boxedWeight == null ? 1.0 : boxedWeight;
137                    break;
138                case SECOND:
139                    boxedWeight = weights.get(e.getKey().getB());
140                    weight = boxedWeight == null ? 1.0 : boxedWeight;
141                    break;
142                case THIRD:
143                    boxedWeight = weights.get(e.getKey().getC());
144                    weight = boxedWeight == null ? 1.0 : boxedWeight;
145                    break;
146            }
147
148            jmi += weight * prob * Math.log((vecLength*jointCurCount)/(abCurCount*cCurCount));
149        }
150        jmi /= LOG_BASE;
151
152        double stateRatio = vecLength / jointCount.size();
153        if (stateRatio < SAMPLES_RATIO) {
154            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", new Object[]{jmi, stateRatio, vecLength, jointCount.size()});
155        }
156
157        return jmi;
158    }
159
160    /**
161     * Calculates the discrete weighted conditional mutual information, using
162     * histogram probability estimators. Arrays must be the same length.
163     * @param <T1> Type contained in the first array.
164     * @param <T2> Type contained in the second array.
165     * @param <T3> Type contained in the condition array.
166     * @param first An array of values.
167     * @param second Another array of values.
168     * @param condition Array to condition upon.
169     * @param weights Array of weight values.
170     * @return The conditional mutual information I(first;second|condition)
171     */
172    public static <T1,T2,T3> double conditionalMI(List<T1> first, List<T2> second, List<T3> condition, List<Double> weights) {
173        if ((first.size() == second.size()) && (first.size() == condition.size()) && (first.size() == weights.size())) {
174            WeightedTripleDistribution<T1,T2,T3> tripleRV = WeightedTripleDistribution.constructFromLists(first, second, condition, weights);
175
176            return conditionalMI(tripleRV);
177        } else {
178            throw new IllegalArgumentException("Weighted Conditional Mutual Information requires four vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", condition.size() = " + condition.size() + ", weights.size() = "+ weights.size());
179        }
180    }
181
182    public static <T1,T2,T3> double conditionalMI(WeightedTripleDistribution<T1,T2,T3> tripleRV) {
183        Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount = tripleRV.getJointCount();
184        Map<CachedPair<T1,T3>,WeightCountTuple> acCount = tripleRV.getACCount();
185        Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = tripleRV.getBCCount();
186        Map<T3,WeightCountTuple> cCount = tripleRV.getCCount();
187
188        double vectorLength = tripleRV.count;
189        double cmi = 0.0;
190        for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) {
191            double weight = e.getValue().weight;
192            double jointCurCount = e.getValue().count;
193            double prob = jointCurCount / vectorLength;
194            CachedPair<T1,T3> acPair = e.getKey().getAC();
195            CachedPair<T2,T3> bcPair = e.getKey().getBC();
196            double acCurCount = acCount.get(acPair).count;
197            double bcCurCount = bcCount.get(bcPair).count;
198            double cCurCount = cCount.get(e.getKey().getC()).count;
199
200            cmi += weight * prob * Math.log((cCurCount*jointCurCount)/(acCurCount*bcCurCount));
201        }
202        cmi /= LOG_BASE;
203
204        double stateRatio = vectorLength / jointCount.size();
205        if (stateRatio < SAMPLES_RATIO) {
206            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio});
207        }
208
209        return cmi;
210    }
211
212    public static <T1,T2,T3> double conditionalMI(TripleDistribution<T1,T2,T3> rv, Map<?,Double> weights, VariableSelector vs) {
213        Double boxedWeight;
214        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = rv.getJointCount();
215        Map<CachedPair<T1,T3>,MutableLong> acCount = rv.getACCount();
216        Map<CachedPair<T2,T3>,MutableLong> bcCount = rv.getBCCount();
217        Map<T3,MutableLong> cCount = rv.getCCount();
218
219        double vectorLength = rv.count;
220        double cmi = 0.0;
221        for (Entry<CachedTriple<T1, T2, T3>, MutableLong> e : jointCount.entrySet()) {
222            double jointCurCount = e.getValue().doubleValue();
223            double prob = jointCurCount / vectorLength;
224            CachedPair<T1, T3> acPair = new CachedPair<>(e.getKey().getA(), e.getKey().getC());
225            CachedPair<T2, T3> bcPair = new CachedPair<>(e.getKey().getB(), e.getKey().getC());
226            double acCurCount = acCount.get(acPair).doubleValue();
227            double bcCurCount = bcCount.get(bcPair).doubleValue();
228            double cCurCount = cCount.get(e.getKey().getC()).doubleValue();
229
230            double weight = 1.0;
231            switch (vs) {
232                case FIRST:
233                    boxedWeight = weights.get(e.getKey().getA());
234                    weight = boxedWeight == null ? 1.0 : boxedWeight;
235                    break;
236                case SECOND:
237                    boxedWeight = weights.get(e.getKey().getB());
238                    weight = boxedWeight == null ? 1.0 : boxedWeight;
239                    break;
240                case THIRD:
241                    boxedWeight = weights.get(e.getKey().getC());
242                    weight = boxedWeight == null ? 1.0 : boxedWeight;
243                    break;
244            }
245            cmi += weight * prob * Math.log((cCurCount * jointCurCount) / (acCurCount * bcCurCount));
246        }
247        cmi /= LOG_BASE;
248
249        double stateRatio = vectorLength / jointCount.size();
250        if (stateRatio < SAMPLES_RATIO) {
251            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio});
252        }
253
254        return cmi;
255    }
256
257    /**
258     * Calculates the discrete weighted mutual information, using histogram
259     * probability estimators.
260     * <p>
261     * Arrays must be the same length.
262     * @param <T1> Type of the first array
263     * @param <T2> Type of the second array
264     * @param first An array of values
265     * @param second Another array of values
266     * @param weights Array of weight values.
267     * @return The mutual information I(first;Second)
268     */
269    public static <T1,T2> double mi(ArrayList<T1> first, ArrayList<T2> second, ArrayList<Double> weights) {
270        if ((first.size() == second.size()) && (first.size() == weights.size())) {
271            WeightedPairDistribution<T1,T2> countPair = WeightedPairDistribution.constructFromLists(first,second,weights);
272            return mi(countPair);
273        } else {
274            throw new IllegalArgumentException("Weighted Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
275        }
276    }
277
278    public static <T1,T2> double mi(WeightedPairDistribution<T1,T2> jointDist) {
279        double vectorLength = jointDist.count;
280        double mi = 0.0;
281        Map<CachedPair<T1,T2>,WeightCountTuple> countDist = jointDist.getJointCounts();
282        Map<T1,WeightCountTuple> firstCountDist = jointDist.getFirstCount();
283        Map<T2,WeightCountTuple> secondCountDist = jointDist.getSecondCount();
284
285        for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
286            double weight = e.getValue().weight;
287            double jointCount = e.getValue().count;
288            double prob = jointCount / vectorLength;
289            double firstCount = firstCountDist.get(e.getKey().getA()).count;
290            double secondCount = secondCountDist.get(e.getKey().getB()).count;
291
292            mi += weight * prob * Math.log((vectorLength*jointCount)/(firstCount*secondCount));
293        }
294        mi /= LOG_BASE;
295
296        double stateRatio = vectorLength / countDist.size();
297        if (stateRatio < SAMPLES_RATIO) {
298            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio});
299        }
300
301        return mi;
302    }
303
304    public static <T1,T2> double mi(PairDistribution<T1,T2> pairDist, Map<?,Double> weights, VariableSelector vs) {
305        if (vs == VariableSelector.THIRD) {
306            throw new IllegalArgumentException("MI only has two variables");
307        }
308        Map<CachedPair<T1,T2>,MutableLong> countDist = pairDist.jointCounts;
309        Map<T1,MutableLong> firstCountDist = pairDist.firstCount;
310        Map<T2,MutableLong> secondCountDist = pairDist.secondCount;
311
312        Double boxedWeight;
313        double vectorLength = pairDist.count;
314        double mi = 0.0;
315        boolean error = false;
316        for (Entry<CachedPair<T1,T2>,MutableLong> e : countDist.entrySet()) {
317            double jointCount = e.getValue().doubleValue();
318            double prob = jointCount / vectorLength;
319            double firstProb = firstCountDist.get(e.getKey().getA()).doubleValue();
320            double secondProb = secondCountDist.get(e.getKey().getB()).doubleValue();
321
322            double top = vectorLength * jointCount;
323            double bottom = firstProb * secondProb;
324            double ratio = top/bottom;
325            double logRatio = Math.log(ratio);
326
327            if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(mi)) {
328                logger.log(Level.WARNING, "State = " + e.getKey().toString());
329                logger.log(Level.WARNING, "mi = " + mi + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio);
330                error = true;
331            }
332
333            double weight = 1.0;
334            switch (vs) {
335                case FIRST:
336                    boxedWeight = weights.get(e.getKey().getA());
337                    weight = boxedWeight == null ? 1.0 : boxedWeight;
338                    break;
339                case SECOND:
340                    boxedWeight = weights.get(e.getKey().getB());
341                    weight = boxedWeight == null ? 1.0 : boxedWeight;
342                    break;
343                default:
344                    throw new IllegalArgumentException("VariableSelector.THIRD not allowed in a two variable calculation.");
345            }
346            mi += weight * prob * logRatio;
347            //mi += prob * Math.log((vectorLength*jointCount)/(firstProb*secondProb));
348        }
349        mi /= LOG_BASE;
350
351        double stateRatio = vectorLength / countDist.size();
352        if (stateRatio < SAMPLES_RATIO) {
353            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio});
354        }
355
356        if (error) {
357            logger.log(Level.SEVERE, "NanFound ", new IllegalStateException("NaN found"));
358        }
359
360        return mi;
361    }
362
363    /**
364     * Calculates the Shannon/Guiasu weighted joint entropy of two arrays, 
365     * using histogram probability estimators. 
366     * <p>
367     * Arrays must be same length.
368     * @param <T1> Type of the first array.
369     * @param <T2> Type of the second array.
370     * @param first An array of values.
371     * @param second Another array of values.
372     * @param weights Array of weight values.
373     * @return The entropy H(first,second)
374     */
375    public static <T1,T2> double jointEntropy(ArrayList<T1> first, ArrayList<T2> second, ArrayList<Double> weights) {
376        if ((first.size() == second.size()) && (first.size() == weights.size())) {
377            double vectorLength = first.size();
378            double jointEntropy = 0.0;
379            
380            WeightedPairDistribution<T1,T2> pairDist = WeightedPairDistribution.constructFromLists(first,second,weights);
381            Map<CachedPair<T1,T2>,WeightCountTuple> countDist = pairDist.getJointCounts();
382
383            for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
384                double prob = e.getValue().count / vectorLength;
385                double weight = e.getValue().weight;
386
387                jointEntropy -= weight * prob * Math.log(prob);
388            }
389            jointEntropy /= LOG_BASE;
390
391            double stateRatio = vectorLength / countDist.size();
392            if (stateRatio < SAMPLES_RATIO) {
393                logger.log(Level.INFO, "Weighted Joint Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{jointEntropy, stateRatio});
394            }
395            
396            return jointEntropy;
397        } else {
398            throw new IllegalArgumentException("Weighted Joint Entropy requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
399        }
400    }
401    
402    /**
403     * Calculates the discrete Shannon/Guiasu Weighted Conditional Entropy of 
404     * two arrays, using histogram probability estimators. 
405     * <p>
406     * Arrays must be the same length.
407     * @param <T1> Type of the first array.
408     * @param <T2> Type of the second array.
409     * @param vector The main array of values.
410     * @param condition The array to condition on.
411     * @param weights Array of weight values.
412     * @return The weighted conditional entropy H_w(vector|condition).
413     */
414    public static <T1,T2> double weightedConditionalEntropy(ArrayList<T1> vector, ArrayList<T2> condition, ArrayList<Double> weights) {
415        if ((vector.size() == condition.size()) && (vector.size() == weights.size())) {
416            double vectorLength = vector.size();
417            double condEntropy = 0.0;
418            
419            WeightedPairDistribution<T1,T2> pairDist = WeightedPairDistribution.constructFromLists(vector,condition,weights);
420            Map<CachedPair<T1,T2>,WeightCountTuple> countDist = pairDist.getJointCounts();
421            Map<T2,WeightCountTuple> conditionCountDist = pairDist.getSecondCount();
422
423            for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
424                double prob = e.getValue().count / vectorLength;
425                double condProb = conditionCountDist.get(e.getKey().getB()).count / vectorLength;
426                double weight = e.getValue().weight;
427
428                condEntropy -= weight * prob * Math.log(prob/condProb);
429            }
430            condEntropy /= LOG_BASE;
431
432            double stateRatio = vectorLength / countDist.size();
433            if (stateRatio < SAMPLES_RATIO) {
434                logger.log(Level.INFO, "Weighted Conditional Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{condEntropy, stateRatio});
435            }
436            
437            return condEntropy;
438        } else {
439            throw new IllegalArgumentException("Weighted Conditional Entropy requires three vectors the same length. vector.size() = " + vector.size() + ", condition.size() = " + condition.size() + ", weights.size() = " + weights.size());
440        }
441    }
442
443    /**
444     * Calculates the discrete Shannon/Guiasu Weighted Entropy, using histogram 
445     * probability estimators.
446     * @param <T> Type of the array.
447     * @param vector The array of values.
448     * @param weights Array of weight values.
449     * @return The weighted entropy H_w(vector).
450     */
451    public static <T> double weightedEntropy(ArrayList<T> vector, ArrayList<Double> weights) {
452        if (vector.size() == weights.size()) {
453            double vectorLength = vector.size();
454            double entropy = 0.0;
455
456            Map<T,WeightCountTuple> countDist = calculateWeightedCountDist(vector,weights);
457            for (Entry<T,WeightCountTuple> e : countDist.entrySet()) {
458                long count = e.getValue().count;
459                double weight = e.getValue().weight;
460                double prob = count / vectorLength;
461                entropy -= weight * prob * Math.log(prob);
462            }
463            entropy /= LOG_BASE;
464
465            double stateRatio = vectorLength / countDist.size();
466            if (stateRatio < SAMPLES_RATIO) {
467                logger.log(Level.INFO, "Weighted Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, stateRatio});
468            }
469            
470            return entropy;
471        } else {
472            throw new IllegalArgumentException("Weighted Entropy requires two vectors the same length. vector.size() = " + vector.size() + ",weights.size() = " + weights.size());
473        }
474    }
475
476    /**
477     * Generate the counts for a single vector.
478     * @param <T> The type inside the vector.
479     * @param vector An array of values.
480     * @param weights The array of weight values.
481     * @return A HashMap from states of T to Pairs of count and total weight for that state.
482     */
483    public static <T> Map<T,WeightCountTuple> calculateWeightedCountDist(ArrayList<T> vector, ArrayList<Double> weights) {
484        Map<T,WeightCountTuple> dist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
485        for (int i = 0; i < vector.size(); i++) {
486            T e = vector.get(i);
487            Double weight = weights.get(i);
488            WeightCountTuple curVal = dist.computeIfAbsent(e,(k) -> new WeightCountTuple());
489            curVal.count += 1;
490            curVal.weight += weight;
491        }
492
493        normaliseWeights(dist);
494
495        return dist;
496    }
497
498    /**
499     * Normalizes the weights in the map, i.e., divides each weight by it's count.
500     * @param map The map to normalize.
501     * @param <T> The type of the variable that was counted.
502     */
503    public static <T> void normaliseWeights(Map<T,WeightCountTuple> map) {
504        for (Entry<T,WeightCountTuple> e : map.entrySet()) {
505            WeightCountTuple tuple = e.getValue();
506            tuple.weight /= tuple.count;
507        }
508    }
509    
510}