/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.util.infotheory.impl;

import org.tribuo.util.infotheory.InformationTheory;
import org.tribuo.util.infotheory.WeightedInformationTheory;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
 * Generates the counts for a pair of vectors. Contains the joint
 * count and the two marginal counts.
 * @param <T1> Type of the first list.
 * @param <T2> Type of the second list.
 */
public class WeightedPairDistribution<T1,T2> {

    /**
     * The total number of samples this distribution was built from.
     */
    public final long count;

    private final Map<CachedPair<T1,T2>,WeightCountTuple> jointCounts;
    private final Map<T1,WeightCountTuple> firstCount;
    private final Map<T2,WeightCountTuple> secondCount;

    /**
     * Constructs a weighted pair distribution, copying the supplied maps.
     * @param count The total sample count.
     * @param jointCounts The joint counts over pairs.
     * @param firstCount The marginal counts for the first element.
     * @param secondCount The marginal counts for the second element.
     */
    public WeightedPairDistribution(long count, Map<CachedPair<T1,T2>,WeightCountTuple> jointCounts, Map<T1,WeightCountTuple> firstCount, Map<T2,WeightCountTuple> secondCount) {
        this.count = count;
        this.jointCounts = new LinkedHashMap<>(jointCounts);
        this.firstCount = new LinkedHashMap<>(firstCount);
        this.secondCount = new LinkedHashMap<>(secondCount);
    }

    /**
     * Constructs a weighted pair distribution which takes ownership of the supplied maps without copying them.
     * @param count The total sample count.
     * @param jointCounts The joint counts over pairs.
     * @param firstCount The marginal counts for the first element.
     * @param secondCount The marginal counts for the second element.
     */
    public WeightedPairDistribution(long count, LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> jointCounts, LinkedHashMap<T1,WeightCountTuple> firstCount, LinkedHashMap<T2,WeightCountTuple> secondCount) {
        this.count = count;
        this.jointCounts = jointCounts;
        this.firstCount = firstCount;
        this.secondCount = secondCount;
    }

    /**
     * Gets the joint counts and weights.
     * @return The joint counts.
     */
    public Map<CachedPair<T1,T2>,WeightCountTuple> getJointCounts() {
        return jointCounts;
    }

    /**
     * Gets the marginal counts and weights for the first element.
     * @return The first element's marginal counts.
     */
    public Map<T1,WeightCountTuple> getFirstCount() {
        return firstCount;
    }

    /**
     * Gets the marginal counts and weights for the second element.
     * @return The second element's marginal counts.
     */
    public Map<T2,WeightCountTuple> getSecondCount() {
        return secondCount;
    }

    /**
     * Generates the counts for two lists. Returns a WeightedPairDistribution
     * containing the joint counts and the two marginal counts.
     * @param <T1> Type of the first list.
     * @param <T2> Type of the second list.
     * @param first A list of values.
     * @param second Another list of values.
     * @param weights A list of per example weights.
     * @return A WeightedPairDistribution.
     */
    public static <T1,T2> WeightedPairDistribution<T1,T2> constructFromLists(List<T1> first, List<T2> second, List<Double> weights) {
        LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> countDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
        LinkedHashMap<T1,WeightCountTuple> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
        LinkedHashMap<T2,WeightCountTuple> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);

        if ((first.size() == second.size()) && (first.size() == weights.size())) {
            long count = 0;
            for (int i = 0; i < first.size(); i++) {
                T1 a = first.get(i);
                T2 b = second.get(i);
                double weight = weights.get(i);
                CachedPair<T1,T2> pair = new CachedPair<>(a,b);

                // Accumulate the weight and occurrence count for the joint cell and both marginals.
                WeightCountTuple abCurCount = countDist.computeIfAbsent(pair,(k) -> new WeightCountTuple());
                abCurCount.weight += weight;
                abCurCount.count++;

                WeightCountTuple aCurCount = aCountDist.computeIfAbsent(a,(k) -> new WeightCountTuple());
                aCurCount.weight += weight;
                aCurCount.count++;

                WeightCountTuple bCurCount = bCountDist.computeIfAbsent(b,(k) -> new WeightCountTuple());
                bCurCount.weight += weight;
                bCurCount.count++;

                count++;
            }

            WeightedInformationTheory.normaliseWeights(countDist);
            WeightedInformationTheory.normaliseWeights(aCountDist);
            WeightedInformationTheory.normaliseWeights(bCountDist);

            return new WeightedPairDistribution<>(count,countDist,aCountDist,bCountDist);
        } else {
            throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
        }
    }

    /**
     * Generates a WeightedPairDistribution by generating the marginal distributions for the first and second elements.
     * This assumes the weights have already been normalised.
     * @param <T1> Type of the first element.
     * @param <T2> Type of the second element.
     * @param jointCount The (normalised) input map.
     * @return A WeightedPairDistribution.
     */
    public static <T1,T2> WeightedPairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,WeightCountTuple> jointCount) {
        LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> countDist = new LinkedHashMap<>(jointCount);
        LinkedHashMap<T1,WeightCountTuple> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
        LinkedHashMap<T2,WeightCountTuple> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);

        long count = 0L;

        for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
            CachedPair<T1,T2> pair = e.getKey();
            WeightCountTuple tuple = e.getValue();
            T1 a = pair.getA();
            T2 b = pair.getB();
            // Undo the per-cell normalisation to recover the total weight contributed by this pair.
            double weight = tuple.weight * tuple.count;

            WeightCountTuple aCurCount = aCountDist.computeIfAbsent(a,(k) -> new WeightCountTuple());
            aCurCount.weight += weight;
            aCurCount.count += tuple.count;

            WeightCountTuple bCurCount = bCountDist.computeIfAbsent(b,(k) -> new WeightCountTuple());
            bCurCount.weight += weight;
            bCurCount.count += tuple.count;

            count += tuple.count;
        }

        WeightedInformationTheory.normaliseWeights(aCountDist);
        WeightedInformationTheory.normaliseWeights(bCountDist);

        return new WeightedPairDistribution<>(count,countDist,aCountDist,bCountDist);
    }

}
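
/*
 * Usage sketch (not part of the original source): a minimal example of how this class
 * might be used to build a weighted joint distribution from two aligned lists. The
 * variable names and sample data are hypothetical, and java.util.Arrays is assumed to
 * be imported; only the constructFromLists factory, the accessors, and the public
 * fields of WeightCountTuple used above are relied upon. After normaliseWeights is
 * applied, each tuple's weight is assumed to hold the average weight per occurrence.
 *
 *     List<String> xs = Arrays.asList("a", "a", "b", "b");
 *     List<Integer> ys = Arrays.asList(0, 1, 0, 0);
 *     List<Double> ws = Arrays.asList(1.0, 0.5, 2.0, 1.5);
 *
 *     WeightedPairDistribution<String,Integer> dist =
 *             WeightedPairDistribution.constructFromLists(xs, ys, ws);
 *
 *     // Each joint cell carries its occurrence count and normalised weight; the
 *     // empirical joint probability is the cell count divided by the total count.
 *     for (Entry<CachedPair<String,Integer>,WeightCountTuple> e : dist.getJointCounts().entrySet()) {
 *         CachedPair<String,Integer> pair = e.getKey();
 *         WeightCountTuple tuple = e.getValue();
 *         double jointProb = ((double) tuple.count) / dist.count;
 *         System.out.println(pair.getA() + "," + pair.getB()
 *                 + " -> count=" + tuple.count + ", weight=" + tuple.weight + ", p=" + jointProb);
 *     }
 */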