001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.util.infotheory.impl; 018 019import com.oracle.labs.mlrg.olcut.util.MutableLong; 020 021import java.util.HashMap; 022import java.util.LinkedHashMap; 023import java.util.List; 024import java.util.Map; 025import java.util.Map.Entry; 026 027/** 028 * Generates the counts for a triplet of vectors. Contains the joint 029 * count, the three pairwise counts, and the three marginal counts. 030 * @param <T1> Type of the first list. 031 * @param <T2> Type of the second list. 032 * @param <T3> Type of the third list. 033 */ 034public class TripleDistribution<T1,T2,T3> { 035 public static final int DEFAULT_MAP_SIZE = 20; 036 037 public final long count; 038 039 private final Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount; 040 private final Map<CachedPair<T1,T2>,MutableLong> abCount; 041 private final Map<CachedPair<T1,T3>,MutableLong> acCount; 042 private final Map<CachedPair<T2,T3>,MutableLong> bcCount; 043 private final Map<T1,MutableLong> aCount; 044 private final Map<T2,MutableLong> bCount; 045 private final Map<T3,MutableLong> cCount; 046 047 public TripleDistribution(long count, Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount, Map<CachedPair<T1,T2>,MutableLong> abCount, Map<CachedPair<T1,T3>,MutableLong> acCount, Map<CachedPair<T2,T3>,MutableLong> bcCount, Map<T1,MutableLong> aCount, Map<T2,MutableLong> bCount, Map<T3,MutableLong> cCount) { 048 this.count = count; 049 this.jointCount = jointCount; 050 this.abCount = abCount; 051 this.acCount = acCount; 052 this.bcCount = bcCount; 053 this.aCount = aCount; 054 this.bCount = bCount; 055 this.cCount = cCount; 056 } 057 058 public Map<CachedTriple<T1,T2,T3>,MutableLong> getJointCount() { 059 return jointCount; 060 } 061 062 public Map<CachedPair<T1,T2>,MutableLong> getABCount() { 063 return abCount; 064 } 065 066 public Map<CachedPair<T1,T3>,MutableLong> getACCount() { 067 return acCount; 068 } 069 070 public Map<CachedPair<T2,T3>,MutableLong> getBCCount() { 071 return bcCount; 072 } 073 074 public Map<T1,MutableLong> getACount() { 075 return aCount; 076 } 077 078 public Map<T2,MutableLong> getBCount() { 079 return bCount; 080 } 081 082 public Map<T3,MutableLong> getCCount() { 083 return cCount; 084 } 085 086 public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third) { 087 Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = new LinkedHashMap<>(DEFAULT_MAP_SIZE); 088 Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(DEFAULT_MAP_SIZE); 089 Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(DEFAULT_MAP_SIZE); 090 Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(DEFAULT_MAP_SIZE); 091 Map<T1,MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE); 092 Map<T2,MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE); 093 Map<T3,MutableLong> cCount = new HashMap<>(DEFAULT_MAP_SIZE); 094 095 long count = first.size(); 096 097 if ((first.size() == second.size()) && (first.size() == third.size())) { 098 for (int i = 0; i < first.size(); i++) { 099 T1 a = first.get(i); 100 T2 b = second.get(i); 101 T3 c = third.get(i); 102 CachedTriple<T1,T2,T3> abc = new CachedTriple<>(a,b,c); 103 CachedPair<T1,T2> ab = abc.getAB(); 104 CachedPair<T1,T3> ac = abc.getAC(); 105 CachedPair<T2,T3> bc = abc.getBC(); 106 107 MutableLong abcCurCount = jointCount.computeIfAbsent(abc, k -> new MutableLong()); 108 abcCurCount.increment(); 109 110 MutableLong abCurCount = abCount.computeIfAbsent(ab, k -> new MutableLong()); 111 abCurCount.increment(); 112 MutableLong acCurCount = acCount.computeIfAbsent(ac, k -> new MutableLong()); 113 acCurCount.increment(); 114 MutableLong bcCurCount = bcCount.computeIfAbsent(bc, k -> new MutableLong()); 115 bcCurCount.increment(); 116 117 MutableLong aCurCount = aCount.computeIfAbsent(a, k -> new MutableLong()); 118 aCurCount.increment(); 119 MutableLong bCurCount = bCount.computeIfAbsent(b, k -> new MutableLong()); 120 bCurCount.increment(); 121 MutableLong cCurCount = cCount.computeIfAbsent(c, k -> new MutableLong()); 122 cCurCount.increment(); 123 } 124 125 return new TripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 126 } else { 127 throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size()); 128 } 129 } 130 131 public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount) { 132 Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(DEFAULT_MAP_SIZE); 133 Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(DEFAULT_MAP_SIZE); 134 Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(DEFAULT_MAP_SIZE); 135 Map<T1,MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE); 136 Map<T2,MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE); 137 Map<T3,MutableLong> cCount = new HashMap<>(DEFAULT_MAP_SIZE); 138 139 return constructFromMap(jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 140 } 141 142 public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount, 143 int abSize, int acSize, int bcSize, 144 int aSize, int bSize, int cSize) { 145 Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(abSize); 146 Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(acSize); 147 Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(bcSize); 148 Map<T1,MutableLong> aCount = new HashMap<>(aSize); 149 Map<T2,MutableLong> bCount = new HashMap<>(bSize); 150 Map<T3,MutableLong> cCount = new HashMap<>(cSize); 151 152 return constructFromMap(jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 153 } 154 155 public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount, 156 Map<CachedPair<T1,T2>,MutableLong> abCount, 157 Map<CachedPair<T1,T3>,MutableLong> acCount, 158 Map<CachedPair<T2,T3>,MutableLong> bcCount, 159 Map<T1,MutableLong> aCount, 160 Map<T2,MutableLong> bCount, 161 Map<T3,MutableLong> cCount) { 162 long count = 0L; 163 164 for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) { 165 CachedTriple<T1,T2,T3> abc = e.getKey(); 166 long curCount = e.getValue().longValue(); 167 CachedPair<T1,T2> ab = abc.getAB(); 168 CachedPair<T1,T3> ac = abc.getAC(); 169 CachedPair<T2,T3> bc = abc.getBC(); 170 T1 a = abc.getA(); 171 T2 b = abc.getB(); 172 T3 c = abc.getC(); 173 174 count += curCount; 175 176 MutableLong abCurCount = abCount.computeIfAbsent(ab, k -> new MutableLong()); 177 abCurCount.increment(curCount); 178 MutableLong acCurCount = acCount.computeIfAbsent(ac, k -> new MutableLong()); 179 acCurCount.increment(curCount); 180 MutableLong bcCurCount = bcCount.computeIfAbsent(bc, k -> new MutableLong()); 181 bcCurCount.increment(curCount); 182 183 MutableLong aCurCount = aCount.computeIfAbsent(a, k -> new MutableLong()); 184 aCurCount.increment(curCount); 185 MutableLong bCurCount = bCount.computeIfAbsent(b, k -> new MutableLong()); 186 bCurCount.increment(curCount); 187 MutableLong cCurCount = cCount.computeIfAbsent(c, k -> new MutableLong()); 188 cCurCount.increment(curCount); 189 } 190 191 return new TripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 192 } 193 194}