/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.util.infotheory.example;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import org.tribuo.util.infotheory.InformationTheory;
import org.tribuo.util.infotheory.impl.CachedTriple;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Demo showing how to calculate various mutual informations and entropies.
 */
public class InformationTheoryDemo {

    private static final Logger logger = Logger.getLogger(InformationTheoryDemo.class.getName());

    // Fixed seed so the demo output is reproducible across runs.
    private static final Random rng = new Random(1);

    /**
     * Generates a sample from a uniform distribution over the integers.
     * @param length The number of samples.
     * @param alphabetSize The alphabet size (i.e., the number of unique values).
     * @return A sample from a uniform distribution.
     */
    public static List<Integer> generateUniform(int length, int alphabetSize) {
        List<Integer> vector = new ArrayList<>(length);

        for (int i = 0; i < length; i++) {
            // Plain append; the positional add(i,...) in the original was redundant.
            vector.add(rng.nextInt(alphabetSize));
        }

        return vector;
    }

    /**
     * Generates a sample from a three variable XOR function.
     * <p>
     * Each list is a binary variable, and the third is the XOR of the first two.
     * @param length The number of samples.
     * @return A sample from an XOR function.
     */
    public static CachedTriple<List<Integer>,List<Integer>,List<Integer>> generateXOR(int length) {
        List<Integer> first = new ArrayList<>(length);
        List<Integer> second = new ArrayList<>(length);
        List<Integer> xor = new ArrayList<>(length);

        for (int i = 0; i < length; i++) {
            int firstVal = rng.nextInt(2);
            int secondVal = rng.nextInt(2);
            int xorVal = firstVal ^ secondVal;
            first.add(firstVal);
            second.add(secondVal);
            xor.add(xorVal);
        }

        return new CachedTriple<>(first,second,xor);
    }

    /**
     * These correlations don't map to mutual information values, as if xyDraw is above xyCorrelation then the draw is completely random.
     * <p>
     * To make it generate correlations of a specific mutual information then it needs to specify the full joint distribution and draw from that.
     * @param length The number of samples.
     * @param alphabetSize The alphabet size (i.e., the number of unique values).
     * @param xyCorrelation Value between 0.0 and 1.0 specifying how likely it is that Y has the same value as X.
     * @param xzCorrelation Value between 0.0 and 1.0 specifying how likely it is that Z has the same value as X.
     * @return A triple of samples drawn from correlated random variables.
     */
    public static CachedTriple<List<Integer>,List<Integer>,List<Integer>> generateCorrelated(int length, int alphabetSize, double xyCorrelation, double xzCorrelation) {
        List<Integer> first = new ArrayList<>(length);
        List<Integer> second = new ArrayList<>(length);
        List<Integer> third = new ArrayList<>(length);

        for (int i = 0; i < length; i++) {
            int firstVal = rng.nextInt(alphabetSize);
            first.add(firstVal);

            // Y copies X with probability xyCorrelation, otherwise draws uniformly.
            double xyDraw = rng.nextDouble();
            if (xyDraw < xyCorrelation) {
                second.add(firstVal);
            } else {
                second.add(rng.nextInt(alphabetSize));
            }

            // Z copies X with probability xzCorrelation, otherwise draws uniformly.
            double xzDraw = rng.nextDouble();
            if (xzDraw < xzCorrelation) {
                third.add(firstVal);
            } else {
                third.add(rng.nextInt(alphabetSize));
            }
        }

        return new CachedTriple<>(first,second,third);
    }

    /**
     * The type of input distribution the demo draws its samples from.
     */
    public enum DistributionType { RANDOM, XOR, CORRELATED }

    /**
     * Command line options for the demo.
     */
    public static class DemoOptions implements Options {
        @Override
        public String getOptionsDescription() {
            return "A demo class showing how to calculate various mutual informations from different inputs.";
        }
        /**
         * The type of the input distribution.
         */
        @Option(charName='t',longName="type",usage="The type of the input distribution.")
        public DistributionType type = DistributionType.RANDOM;
    }

    /**
     * Runs the demo, generating samples of the requested distribution and
     * logging their entropies, joint entropies, mutual informations and G-tests.
     * @param args The command line arguments, parsed into {@link DemoOptions}.
     */
    public static void main(String[] args) {

        DemoOptions options = new DemoOptions();

        try {
            // The constructor populates 'options' as a side effect; the manager
            // itself is not needed afterwards, so no variable is kept.
            new ConfigurationManager(args, options, false);
        } catch (UsageException e) {
            System.out.println(e.getUsage());
            // Fix: exit on invalid arguments instead of silently continuing
            // with the default options.
            return;
        }

        List<Integer> x;
        List<Integer> y;
        List<Integer> z;

        switch (options.type) {
            case RANDOM:
                x = generateUniform(1000, 5);
                y = generateUniform(1000, 5);
                z = generateUniform(1000, 5);
                break;
            case XOR:
                CachedTriple<List<Integer>,List<Integer>,List<Integer>> trip = generateXOR(1000);
                x = trip.getA();
                y = trip.getB();
                z = trip.getC();
                break;
            case CORRELATED:
                CachedTriple<List<Integer>,List<Integer>,List<Integer>> tripC = generateCorrelated(1000,5,0.7,0.5);
                x = tripC.getA();
                y = tripC.getB();
                z = tripC.getC();
                break;
            default:
                logger.log(Level.WARNING, "Unknown test case, exiting");
                return;
        }

        double hx = InformationTheory.entropy(x);
        double hy = InformationTheory.entropy(y);
        double hz = InformationTheory.entropy(z);

        double hxy = InformationTheory.jointEntropy(x,y);
        double hxz = InformationTheory.jointEntropy(x,z);
        double hyz = InformationTheory.jointEntropy(y,z);

        double ixy = InformationTheory.mi(x,y);
        double ixz = InformationTheory.mi(x,z);
        double iyz = InformationTheory.mi(y,z);

        InformationTheory.GTestStatistics gxy = InformationTheory.gTest(x,y,null);
        InformationTheory.GTestStatistics gxz = InformationTheory.gTest(x,z,null);
        InformationTheory.GTestStatistics gyz = InformationTheory.gTest(y,z,null);

        if (InformationTheory.LOG_BASE == InformationTheory.LOG_2) {
            logger.log(Level.INFO, "Using log_2");
        } else if (InformationTheory.LOG_BASE == InformationTheory.LOG_E) {
            logger.log(Level.INFO, "Using log_e");
        } else {
            logger.log(Level.INFO, "Using unexpected log base, LOG_BASE = " + InformationTheory.LOG_BASE);
        }

        logger.log(Level.INFO, "The entropy of X, H(X) is " + hx);
        logger.log(Level.INFO, "The entropy of Y, H(Y) is " + hy);
        logger.log(Level.INFO, "The entropy of Z, H(Z) is " + hz);

        logger.log(Level.INFO, "The joint entropy of X and Y, H(X,Y) is " + hxy);
        logger.log(Level.INFO, "The joint entropy of X and Z, H(X,Z) is " + hxz);
        logger.log(Level.INFO, "The joint entropy of Y and Z, H(Y,Z) is " + hyz);

        logger.log(Level.INFO, "The mutual information between X and Y, I(X;Y) is " + ixy);
        logger.log(Level.INFO, "The mutual information between X and Z, I(X;Z) is " + ixz);
        logger.log(Level.INFO, "The mutual information between Y and Z, I(Y;Z) is " + iyz);

        logger.log(Level.INFO, "The G-Test between X and Y, G(X;Y) is " + gxy);
        logger.log(Level.INFO, "The G-Test between X and Z, G(X;Z) is " + gxz);
        logger.log(Level.INFO, "The G-Test between Y and Z, G(Y;Z) is " + gyz);
    }

}