001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.util.infotheory.impl; 018 019import org.tribuo.util.infotheory.WeightedInformationTheory; 020 021import java.util.HashMap; 022import java.util.List; 023import java.util.Map; 024import java.util.Map.Entry; 025 026/** 027 * Generates the counts for a triplet of vectors. Contains the joint 028 * count, the three pairwise counts, and the three marginal counts. 029 * @param <T1> Type of the first list. 030 * @param <T2> Type of the second list. 031 * @param <T3> Type of the third list. 032 */ 033public class WeightedTripleDistribution<T1,T2,T3> { 034 public static final int DEFAULT_MAP_SIZE = 20; 035 036 public final long count; 037 038 private final Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount; 039 private final Map<CachedPair<T1,T2>,WeightCountTuple> abCount; 040 private final Map<CachedPair<T1,T3>,WeightCountTuple> acCount; 041 private final Map<CachedPair<T2,T3>,WeightCountTuple> bcCount; 042 private final Map<T1,WeightCountTuple> aCount; 043 private final Map<T2,WeightCountTuple> bCount; 044 private final Map<T3,WeightCountTuple> cCount; 045 046 public WeightedTripleDistribution(long count, Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount, Map<CachedPair<T1,T2>,WeightCountTuple> abCount, Map<CachedPair<T1,T3>,WeightCountTuple> acCount, Map<CachedPair<T2,T3>,WeightCountTuple> bcCount, Map<T1,WeightCountTuple> aCount, Map<T2,WeightCountTuple> bCount, Map<T3,WeightCountTuple> cCount) { 047 this.count = count; 048 this.jointCount = jointCount; 049 this.abCount = abCount; 050 this.acCount = acCount; 051 this.bcCount = bcCount; 052 this.aCount = aCount; 053 this.bCount = bCount; 054 this.cCount = cCount; 055 } 056 057 public Map<CachedTriple<T1,T2,T3>,WeightCountTuple> getJointCount() { 058 return jointCount; 059 } 060 061 public Map<CachedPair<T1,T2>,WeightCountTuple> getABCount() { 062 return abCount; 063 } 064 065 public Map<CachedPair<T1,T3>,WeightCountTuple> getACCount() { 066 return acCount; 067 } 068 069 public Map<CachedPair<T2,T3>,WeightCountTuple> getBCCount() { 070 return bcCount; 071 } 072 073 public Map<T1,WeightCountTuple> getACount() { 074 return aCount; 075 } 076 077 public Map<T2,WeightCountTuple> getBCount() { 078 return bCount; 079 } 080 081 public Map<T3,WeightCountTuple> getCCount() { 082 return cCount; 083 } 084 085 public static <T1,T2,T3> WeightedTripleDistribution<T1,T2,T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third, List<Double> weights) { 086 Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount = new HashMap<>(DEFAULT_MAP_SIZE); 087 Map<CachedPair<T1,T2>,WeightCountTuple> abCount = new HashMap<>(DEFAULT_MAP_SIZE); 088 Map<CachedPair<T1,T3>,WeightCountTuple> acCount = new HashMap<>(DEFAULT_MAP_SIZE); 089 Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = new HashMap<>(DEFAULT_MAP_SIZE); 090 Map<T1,WeightCountTuple> aCount = new HashMap<>(DEFAULT_MAP_SIZE); 091 Map<T2,WeightCountTuple> bCount = new HashMap<>(DEFAULT_MAP_SIZE); 092 Map<T3,WeightCountTuple> cCount = new HashMap<>(DEFAULT_MAP_SIZE); 093 094 long count = first.size(); 095 096 if ((first.size() == second.size()) && (first.size() == third.size()) && (first.size() == weights.size())) { 097 for (int i = 0; i < first.size(); i++) { 098 double weight = weights.get(i); 099 T1 a = first.get(i); 100 T2 b = second.get(i); 101 T3 c = third.get(i); 102 CachedTriple<T1,T2,T3> triple = new CachedTriple<>(a,b,c); 103 CachedPair<T1,T2> abPair = triple.getAB(); 104 CachedPair<T1,T3> acPair = triple.getAC(); 105 CachedPair<T2,T3> bcPair = triple.getBC(); 106 107 WeightCountTuple abcCurCount = jointCount.computeIfAbsent(triple,(k) -> new WeightCountTuple()); 108 abcCurCount.weight += weight; 109 abcCurCount.count++; 110 111 WeightCountTuple abCurCount = abCount.computeIfAbsent(abPair,(k) -> new WeightCountTuple()); 112 abCurCount.weight += weight; 113 abCurCount.count++; 114 115 WeightCountTuple acCurCount = acCount.computeIfAbsent(acPair,(k) -> new WeightCountTuple()); 116 acCurCount.weight += weight; 117 acCurCount.count++; 118 119 WeightCountTuple bcCurCount = bcCount.computeIfAbsent(bcPair,(k) -> new WeightCountTuple()); 120 bcCurCount.weight += weight; 121 bcCurCount.count++; 122 123 WeightCountTuple aCurCount = aCount.computeIfAbsent(a,(k) -> new WeightCountTuple()); 124 aCurCount.weight += weight; 125 aCurCount.count++; 126 127 WeightCountTuple bCurCount = bCount.computeIfAbsent(b,(k) -> new WeightCountTuple()); 128 bCurCount.weight += weight; 129 bCurCount.count++; 130 131 WeightCountTuple cCurCount = cCount.computeIfAbsent(c,(k) -> new WeightCountTuple()); 132 cCurCount.weight += weight; 133 cCurCount.count++; 134 } 135 136 WeightedInformationTheory.normaliseWeights(jointCount); 137 WeightedInformationTheory.normaliseWeights(abCount); 138 WeightedInformationTheory.normaliseWeights(acCount); 139 WeightedInformationTheory.normaliseWeights(bcCount); 140 WeightedInformationTheory.normaliseWeights(aCount); 141 WeightedInformationTheory.normaliseWeights(bCount); 142 WeightedInformationTheory.normaliseWeights(cCount); 143 144 return new WeightedTripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 145 } else { 146 throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size() + ", weights.size() = " + weights.size()); 147 } 148 } 149 150 public static <T1,T2,T3> WeightedTripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount) { 151 Map<CachedPair<T1,T2>,WeightCountTuple> abCount = new HashMap<>(DEFAULT_MAP_SIZE); 152 Map<CachedPair<T1,T3>,WeightCountTuple> acCount = new HashMap<>(DEFAULT_MAP_SIZE); 153 Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = new HashMap<>(DEFAULT_MAP_SIZE); 154 Map<T1,WeightCountTuple> aCount = new HashMap<>(DEFAULT_MAP_SIZE); 155 Map<T2,WeightCountTuple> bCount = new HashMap<>(DEFAULT_MAP_SIZE); 156 Map<T3,WeightCountTuple> cCount = new HashMap<>(DEFAULT_MAP_SIZE); 157 158 long count = 0L; 159 160 for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) { 161 CachedTriple<T1,T2,T3> triple = e.getKey(); 162 WeightCountTuple tuple = e.getValue(); 163 CachedPair<T1,T2> abPair = triple.getAB(); 164 CachedPair<T1,T3> acPair = triple.getAC(); 165 CachedPair<T2,T3> bcPair = triple.getBC(); 166 T1 a = triple.getA(); 167 T2 b = triple.getB(); 168 T3 c = triple.getC(); 169 170 count += tuple.count; 171 172 double weight = tuple.weight * tuple.count; 173 174 WeightCountTuple abCurCount = abCount.computeIfAbsent(abPair,(k) -> new WeightCountTuple()); 175 abCurCount.weight += weight; 176 abCurCount.count += tuple.count; 177 178 WeightCountTuple acCurCount = acCount.computeIfAbsent(acPair,(k) -> new WeightCountTuple()); 179 acCurCount.weight += weight; 180 acCurCount.count += tuple.count; 181 182 WeightCountTuple bcCurCount = bcCount.computeIfAbsent(bcPair,(k) -> new WeightCountTuple()); 183 bcCurCount.weight += weight; 184 bcCurCount.count += tuple.count; 185 186 WeightCountTuple aCurCount = aCount.computeIfAbsent(a,(k) -> new WeightCountTuple()); 187 aCurCount.weight += weight; 188 aCurCount.count += tuple.count; 189 190 WeightCountTuple bCurCount = bCount.computeIfAbsent(b,(k) -> new WeightCountTuple()); 191 bCurCount.weight += weight; 192 bCurCount.count += tuple.count; 193 194 WeightCountTuple cCurCount = cCount.computeIfAbsent(c,(k) -> new WeightCountTuple()); 195 cCurCount.weight += weight; 196 cCurCount.count += tuple.count; 197 } 198 199 WeightedInformationTheory.normaliseWeights(abCount); 200 WeightedInformationTheory.normaliseWeights(acCount); 201 WeightedInformationTheory.normaliseWeights(bcCount); 202 WeightedInformationTheory.normaliseWeights(aCount); 203 WeightedInformationTheory.normaliseWeights(bCount); 204 WeightedInformationTheory.normaliseWeights(cCount); 205 206 return new WeightedTripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount); 207 } 208 209}