001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.util.infotheory.impl; 018 019import com.oracle.labs.mlrg.olcut.util.MutableLong; 020import org.tribuo.util.infotheory.InformationTheory; 021 022import java.util.HashMap; 023import java.util.LinkedHashMap; 024import java.util.List; 025import java.util.Map; 026import java.util.Map.Entry; 027 028/** 029 * A count distribution over {@link CachedPair} objects. 030 * @param <T1> The type of the first element 031 * @param <T2> The type of the second element 032 */ 033public class PairDistribution<T1,T2> { 034 035 public final long count; 036 037 public final Map<CachedPair<T1,T2>,MutableLong> jointCounts; 038 public final Map<T1,MutableLong> firstCount; 039 public final Map<T2,MutableLong> secondCount; 040 041 public PairDistribution(long count, Map<CachedPair<T1,T2>,MutableLong> jointCounts, Map<T1,MutableLong> firstCount, Map<T2,MutableLong> secondCount) { 042 this.count = count; 043 this.jointCounts = new LinkedHashMap<>(jointCounts); 044 this.firstCount = new LinkedHashMap<>(firstCount); 045 this.secondCount = new LinkedHashMap<>(secondCount); 046 } 047 048 public PairDistribution(long count, LinkedHashMap<CachedPair<T1,T2>,MutableLong> jointCounts, LinkedHashMap<T1,MutableLong> firstCount, LinkedHashMap<T2,MutableLong> secondCount) { 049 this.count = count; 050 this.jointCounts = jointCounts; 051 this.firstCount = firstCount; 052 this.secondCount = secondCount; 053 } 054 055 /** 056 * Generates the counts for two vectors. Returns a PairDistribution containing the joint 057 * count, and the two marginal counts. 058 * @param <T1> Type of the first array. 059 * @param <T2> Type of the second array. 060 * @param first An array of values. 061 * @param second Another array of values. 062 * @return The joint counts and the two marginal counts. 063 */ 064 public static <T1,T2> PairDistribution<T1,T2> constructFromLists(List<T1> first, List<T2> second) { 065 LinkedHashMap<CachedPair<T1,T2>,MutableLong> abCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 066 LinkedHashMap<T1,MutableLong> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 067 LinkedHashMap<T2,MutableLong> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 068 069 if (first.size() == second.size()) { 070 long count = 0; 071 for (int i = 0; i < first.size(); i++) { 072 T1 a = first.get(i); 073 T2 b = second.get(i); 074 CachedPair<T1,T2> pair = new CachedPair<>(a,b); 075 076 MutableLong abCount = abCountDist.computeIfAbsent(pair, k -> new MutableLong()); 077 abCount.increment(); 078 079 MutableLong aCount = aCountDist.computeIfAbsent(a, k -> new MutableLong()); 080 aCount.increment(); 081 082 MutableLong bCount = bCountDist.computeIfAbsent(b, k -> new MutableLong()); 083 bCount.increment(); 084 085 count++; 086 } 087 088 return new PairDistribution<>(count,abCountDist,aCountDist,bCountDist); 089 } else { 090 throw new IllegalArgumentException("Counting requires arrays of the same length. first.size() = " + first.size() + ", second.size() = " + second.size()); 091 } 092 } 093 094 public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount) { 095 Map<T1,MutableLong> aCount = new HashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 096 Map<T2,MutableLong> bCount = new HashMap<>(InformationTheory.DEFAULT_MAP_SIZE); 097 098 return constructFromMap(jointCount,aCount,bCount); 099 } 100 101 public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount, int aSize, int bSize) { 102 Map<T1,MutableLong> aCount = new HashMap<>(aSize); 103 Map<T2,MutableLong> bCount = new HashMap<>(bSize); 104 105 return constructFromMap(jointCount,aCount,bCount); 106 } 107 108 public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount, 109 Map<T1,MutableLong> aCount, 110 Map<T2,MutableLong> bCount) { 111 long count = 0L; 112 113 for (Entry<CachedPair<T1,T2>,MutableLong> e : jointCount.entrySet()) { 114 CachedPair<T1,T2> pair = e.getKey(); 115 long curCount = e.getValue().longValue(); 116 T1 a = pair.getA(); 117 T2 b = pair.getB(); 118 119 MutableLong curACount = aCount.computeIfAbsent(a, k -> new MutableLong()); 120 curACount.increment(curCount); 121 122 MutableLong curBCount = bCount.computeIfAbsent(b, k -> new MutableLong()); 123 curBCount.increment(curCount); 124 count += curCount; 125 } 126 127 return new PairDistribution<>(count,jointCount,aCount,bCount); 128 } 129 130}