001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory.example;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import org.tribuo.util.infotheory.InformationTheory;
024import org.tribuo.util.infotheory.impl.CachedTriple;
025
026import java.util.ArrayList;
027import java.util.List;
028import java.util.Random;
029import java.util.logging.Level;
030import java.util.logging.Logger;
031
032/**
033 * Demo showing how to calculate various mutual informations and entropies.
034 */
035public class InformationTheoryDemo {
036
037    private static final Logger logger = Logger.getLogger(InformationTheoryDemo.class.getName());
038
039    private static final Random rng = new Random(1);
040
041    /**
042     * Generates a sample from a uniform distribution over the integers.
043     * @param length The number of samples.
044     * @param alphabetSize The alphabet size (i.e., the number of unique values).
045     * @return A sample from a uniform distribution.
046     */
047    public static List<Integer> generateUniform(int length, int alphabetSize) {
048        List<Integer> vector = new ArrayList<>(length);
049
050        for (int i = 0; i < length; i++) {
051            vector.add(i,rng.nextInt(alphabetSize));
052        }
053
054        return vector;
055    }
056
057    /**
058     * Generates a sample from a three variable XOR function.
059     * <p>
060     * Each list is a binary variable, and the third is the XOR of the first two.
061     * @param length The number of samples.
062     * @return A sample from an XOR function.
063     */
064    public static CachedTriple<List<Integer>,List<Integer>,List<Integer>> generateXOR(int length) {
065        List<Integer> first = new ArrayList<>(length);
066        List<Integer> second = new ArrayList<>(length);
067        List<Integer> xor = new ArrayList<>(length);
068
069        for (int i = 0; i < length; i++) {
070            int firstVal = rng.nextInt(2);
071            int secondVal = rng.nextInt(2);
072            int xorVal = firstVal ^ secondVal;
073            first.add(i,firstVal);
074            second.add(i,secondVal);
075            xor.add(i,xorVal);
076        }
077
078        return new CachedTriple<>(first,second,xor);
079    }
080
081    /**
082     * These correlations don't map to mutual information values, as if xyDraw is above xyCorrelation then the draw is completely random.
083     * <p>
084     * To make it generate correlations of a specific mutual information then it needs to specify the full joint distribution and draw from that.
085     * @param length The number of samples.
086     * @param alphabetSize The alphabet size (i.e., the number of unique values).
087     * @param xyCorrelation Value between 0.0 and 1.0 specifying how likely it is that Y has the same value as X.
088     * @param xzCorrelation Value between 0.0 and 1.0 specifying how likely it is that Z has the same value as X.
089     * @return A triple of samples drawn from correlated random variables.
090     */
091    public static CachedTriple<List<Integer>,List<Integer>,List<Integer>> generateCorrelated(int length, int alphabetSize, double xyCorrelation, double xzCorrelation) {
092        List<Integer> first = new ArrayList<>(length);
093        List<Integer> second = new ArrayList<>(length);
094        List<Integer> third = new ArrayList<>(length);
095
096        for (int i = 0; i < length; i++) {
097            int firstVal = rng.nextInt(alphabetSize);
098            first.add(firstVal);
099
100            double xyDraw = rng.nextDouble();
101            if (xyDraw < xyCorrelation) {
102                second.add(firstVal);
103            } else {
104                second.add(rng.nextInt(alphabetSize));
105            }
106
107            double xzDraw = rng.nextDouble();
108            if (xzDraw < xzCorrelation) {
109                third.add(firstVal);
110            } else {
111                third.add(rng.nextInt(alphabetSize));
112            }
113        }
114
115        return new CachedTriple<>(first,second,third);
116    }
117
118    public enum DistributionType { RANDOM, XOR, CORRELATED }
119
120    public static class DemoOptions implements Options {
121        @Override
122        public String getOptionsDescription() {
123            return "A demo class showing how to calculate various mutual informations from different inputs.";
124        }
125        @Option(charName='t',longName="type",usage="The type of the input distribution.")
126        public DistributionType type = DistributionType.RANDOM;
127    }
128
129    public static void main(String[] args) {
130
131        DemoOptions options = new DemoOptions();
132
133        try {
134            ConfigurationManager cm = new ConfigurationManager(args, options, false);
135        } catch (UsageException e) {
136            System.out.println(e.getUsage());
137        }
138
139        List<Integer> x;
140        List<Integer> y;
141        List<Integer> z;
142
143        switch (options.type) {
144            case RANDOM:
145                x = generateUniform(1000, 5);
146                y = generateUniform(1000, 5);
147                z = generateUniform(1000, 5);
148                break;
149            case XOR:
150                CachedTriple<List<Integer>,List<Integer>,List<Integer>> trip = generateXOR(1000);
151                x = trip.getA();
152                y = trip.getB();
153                z = trip.getC();
154                break;
155            case CORRELATED:
156                CachedTriple<List<Integer>,List<Integer>,List<Integer>> tripC = generateCorrelated(1000,5,0.7,0.5);
157                x = tripC.getA();
158                y = tripC.getB();
159                z = tripC.getC();
160                break;
161            default:
162                logger.log(Level.WARNING, "Unknown test case, exiting");
163                return;
164        }
165
166        double hx = InformationTheory.entropy(x);
167        double hy = InformationTheory.entropy(y);
168        double hz = InformationTheory.entropy(z);
169
170        double hxy = InformationTheory.jointEntropy(x,y);
171        double hxz = InformationTheory.jointEntropy(x,z);
172        double hyz = InformationTheory.jointEntropy(y,z);
173        
174        double ixy = InformationTheory.mi(x,y);
175        double ixz = InformationTheory.mi(x,z);
176        double iyz = InformationTheory.mi(y,z);
177        
178        InformationTheory.GTestStatistics gxy = InformationTheory.gTest(x,y,null);
179        InformationTheory.GTestStatistics gxz = InformationTheory.gTest(x,z,null);
180        InformationTheory.GTestStatistics gyz = InformationTheory.gTest(y,z,null);
181
182        if (InformationTheory.LOG_BASE == InformationTheory.LOG_2) {
183            logger.log(Level.INFO, "Using log_2");
184        } else if (InformationTheory.LOG_BASE == InformationTheory.LOG_E) {
185            logger.log(Level.INFO, "Using log_e");
186        } else {
187            logger.log(Level.INFO, "Using unexpected log base, LOG_BASE = " + InformationTheory.LOG_BASE);
188        }
189        
190        logger.log(Level.INFO, "The entropy of X, H(X) is " + hx);
191        logger.log(Level.INFO, "The entropy of Y, H(Y) is " + hy);
192        logger.log(Level.INFO, "The entropy of Z, H(Z) is " + hz);
193        
194        logger.log(Level.INFO, "The joint entropy of X and Y, H(X,Y) is " + hxy);
195        logger.log(Level.INFO, "The joint entropy of X and Z, H(X,Z) is " + hxz);
196        logger.log(Level.INFO, "The joint entropy of Y and Z, H(Y,Z) is " + hyz);
197
198        logger.log(Level.INFO, "The mutual information between X and Y, I(X;Y) is " + ixy);
199        logger.log(Level.INFO, "The mutual information between X and Z, I(X;Z) is " + ixz);
200        logger.log(Level.INFO, "The mutual information between Y and Z, I(Y;Z) is " + iyz);
201
202        logger.log(Level.INFO, "The G-Test between X and Y, G(X;Y) is " + gxy);
203        logger.log(Level.INFO, "The G-Test between X and Z, G(X;Z) is " + gxz);
204        logger.log(Level.INFO, "The G-Test between Y and Z, G(Y;Z) is " + gyz);
205    }
206    
207}