001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory.impl;
018
019import org.tribuo.util.infotheory.WeightedInformationTheory;
020
021import java.util.HashMap;
022import java.util.List;
023import java.util.Map;
024import java.util.Map.Entry;
025
026/**
027 * Generates the counts for a triplet of vectors. Contains the joint
028 * count, the three pairwise counts, and the three marginal counts.
029 * @param <T1> Type of the first list.
030 * @param <T2> Type of the second list.
031 * @param <T3> Type of the third list.
032 */
033public class WeightedTripleDistribution<T1,T2,T3> {
034    public static final int DEFAULT_MAP_SIZE = 20;
035
036    public final long count;
037
038    private final Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount;
039    private final Map<CachedPair<T1,T2>,WeightCountTuple> abCount;
040    private final Map<CachedPair<T1,T3>,WeightCountTuple> acCount;
041    private final Map<CachedPair<T2,T3>,WeightCountTuple> bcCount;
042    private final Map<T1,WeightCountTuple> aCount;
043    private final Map<T2,WeightCountTuple> bCount;
044    private final Map<T3,WeightCountTuple> cCount;
045
046    public WeightedTripleDistribution(long count, Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount, Map<CachedPair<T1,T2>,WeightCountTuple> abCount, Map<CachedPair<T1,T3>,WeightCountTuple> acCount, Map<CachedPair<T2,T3>,WeightCountTuple> bcCount, Map<T1,WeightCountTuple> aCount, Map<T2,WeightCountTuple> bCount, Map<T3,WeightCountTuple> cCount) {
047        this.count = count;
048        this.jointCount = jointCount;
049        this.abCount = abCount;
050        this.acCount = acCount;
051        this.bcCount = bcCount;
052        this.aCount = aCount;
053        this.bCount = bCount;
054        this.cCount = cCount;
055    }
056
057    public Map<CachedTriple<T1,T2,T3>,WeightCountTuple> getJointCount() {
058        return jointCount;
059    }
060    
061    public Map<CachedPair<T1,T2>,WeightCountTuple> getABCount() {
062        return abCount;
063    }
064    
065    public Map<CachedPair<T1,T3>,WeightCountTuple> getACCount() {
066        return acCount;
067    }
068    
069    public Map<CachedPair<T2,T3>,WeightCountTuple> getBCCount() {
070        return bcCount;
071    }
072    
073    public Map<T1,WeightCountTuple> getACount() {
074        return aCount;
075    }
076    
077    public Map<T2,WeightCountTuple> getBCount() {
078        return bCount;
079    }
080    
081    public Map<T3,WeightCountTuple> getCCount() {
082        return cCount;
083    }
084    
085    public static <T1,T2,T3> WeightedTripleDistribution<T1,T2,T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third, List<Double> weights) {
086        Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount = new HashMap<>(DEFAULT_MAP_SIZE);
087        Map<CachedPair<T1,T2>,WeightCountTuple> abCount = new HashMap<>(DEFAULT_MAP_SIZE);
088        Map<CachedPair<T1,T3>,WeightCountTuple> acCount = new HashMap<>(DEFAULT_MAP_SIZE);
089        Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = new HashMap<>(DEFAULT_MAP_SIZE);
090        Map<T1,WeightCountTuple> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
091        Map<T2,WeightCountTuple> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
092        Map<T3,WeightCountTuple> cCount = new HashMap<>(DEFAULT_MAP_SIZE);
093
094        long count = first.size();
095
096        if ((first.size() == second.size()) && (first.size() == third.size()) && (first.size() == weights.size())) {
097            for (int i = 0; i < first.size(); i++) {
098                double weight = weights.get(i);
099                T1 a = first.get(i);
100                T2 b = second.get(i);
101                T3 c = third.get(i);
102                CachedTriple<T1,T2,T3> triple = new CachedTriple<>(a,b,c);
103                CachedPair<T1,T2> abPair = triple.getAB();
104                CachedPair<T1,T3> acPair = triple.getAC();
105                CachedPair<T2,T3> bcPair = triple.getBC();
106
107                WeightCountTuple abcCurCount = jointCount.computeIfAbsent(triple,(k) -> new WeightCountTuple());
108                abcCurCount.weight += weight;
109                abcCurCount.count++;
110
111                WeightCountTuple abCurCount = abCount.computeIfAbsent(abPair,(k) -> new WeightCountTuple());
112                abCurCount.weight += weight;
113                abCurCount.count++;
114
115                WeightCountTuple acCurCount = acCount.computeIfAbsent(acPair,(k) -> new WeightCountTuple());
116                acCurCount.weight += weight;
117                acCurCount.count++;
118                
119                WeightCountTuple bcCurCount = bcCount.computeIfAbsent(bcPair,(k) -> new WeightCountTuple());
120                bcCurCount.weight += weight;
121                bcCurCount.count++;
122                
123                WeightCountTuple aCurCount = aCount.computeIfAbsent(a,(k) -> new WeightCountTuple());
124                aCurCount.weight += weight;
125                aCurCount.count++;
126
127                WeightCountTuple bCurCount = bCount.computeIfAbsent(b,(k) -> new WeightCountTuple());
128                bCurCount.weight += weight;
129                bCurCount.count++;
130
131                WeightCountTuple cCurCount = cCount.computeIfAbsent(c,(k) -> new WeightCountTuple());
132                cCurCount.weight += weight;
133                cCurCount.count++;
134            }
135
136            WeightedInformationTheory.normaliseWeights(jointCount);
137            WeightedInformationTheory.normaliseWeights(abCount);
138            WeightedInformationTheory.normaliseWeights(acCount);
139            WeightedInformationTheory.normaliseWeights(bcCount);
140            WeightedInformationTheory.normaliseWeights(aCount);
141            WeightedInformationTheory.normaliseWeights(bCount);
142            WeightedInformationTheory.normaliseWeights(cCount);
143
144            return new WeightedTripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
145        } else {
146            throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size() + ", weights.size() = " + weights.size());
147        }
148    }
149
150    public static <T1,T2,T3> WeightedTripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,WeightCountTuple> jointCount) {
151        Map<CachedPair<T1,T2>,WeightCountTuple> abCount = new HashMap<>(DEFAULT_MAP_SIZE);
152        Map<CachedPair<T1,T3>,WeightCountTuple> acCount = new HashMap<>(DEFAULT_MAP_SIZE);
153        Map<CachedPair<T2,T3>,WeightCountTuple> bcCount = new HashMap<>(DEFAULT_MAP_SIZE);
154        Map<T1,WeightCountTuple> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
155        Map<T2,WeightCountTuple> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
156        Map<T3,WeightCountTuple> cCount = new HashMap<>(DEFAULT_MAP_SIZE);
157        
158        long count = 0L;
159
160        for (Entry<CachedTriple<T1,T2,T3>,WeightCountTuple> e : jointCount.entrySet()) {
161            CachedTriple<T1,T2,T3> triple = e.getKey();
162            WeightCountTuple tuple = e.getValue();
163            CachedPair<T1,T2> abPair = triple.getAB();
164            CachedPair<T1,T3> acPair = triple.getAC();
165            CachedPair<T2,T3> bcPair = triple.getBC();
166            T1 a = triple.getA();
167            T2 b = triple.getB();
168            T3 c = triple.getC();
169
170            count += tuple.count;
171
172            double weight = tuple.weight * tuple.count;
173
174            WeightCountTuple abCurCount = abCount.computeIfAbsent(abPair,(k) -> new WeightCountTuple());
175            abCurCount.weight += weight;
176            abCurCount.count += tuple.count;
177
178            WeightCountTuple acCurCount = acCount.computeIfAbsent(acPair,(k) -> new WeightCountTuple());
179            acCurCount.weight += weight;
180            acCurCount.count += tuple.count;
181
182            WeightCountTuple bcCurCount = bcCount.computeIfAbsent(bcPair,(k) -> new WeightCountTuple());
183            bcCurCount.weight += weight;
184            bcCurCount.count += tuple.count;
185
186            WeightCountTuple aCurCount = aCount.computeIfAbsent(a,(k) -> new WeightCountTuple());
187            aCurCount.weight += weight;
188            aCurCount.count += tuple.count;
189
190            WeightCountTuple bCurCount = bCount.computeIfAbsent(b,(k) -> new WeightCountTuple());
191            bCurCount.weight += weight;
192            bCurCount.count += tuple.count;
193
194            WeightCountTuple cCurCount = cCount.computeIfAbsent(c,(k) -> new WeightCountTuple());
195            cCurCount.weight += weight;
196            cCurCount.count += tuple.count;
197        }
198
199        WeightedInformationTheory.normaliseWeights(abCount);
200        WeightedInformationTheory.normaliseWeights(acCount);
201        WeightedInformationTheory.normaliseWeights(bcCount);
202        WeightedInformationTheory.normaliseWeights(aCount);
203        WeightedInformationTheory.normaliseWeights(bCount);
204        WeightedInformationTheory.normaliseWeights(cCount);
205
206        return new WeightedTripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
207    }
208    
209}