001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory.impl;
018
import com.oracle.labs.mlrg.olcut.util.MutableLong;

import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
026
027/**
028 * Generates the counts for a triplet of vectors. Contains the joint
029 * count, the three pairwise counts, and the three marginal counts.
030 * @param <T1> Type of the first list.
031 * @param <T2> Type of the second list.
032 * @param <T3> Type of the third list.
033 */
034public class TripleDistribution<T1,T2,T3> {
035    public static final int DEFAULT_MAP_SIZE = 20;
036
037    public final long count;
038
039    private final Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount;
040    private final Map<CachedPair<T1,T2>,MutableLong> abCount;
041    private final Map<CachedPair<T1,T3>,MutableLong> acCount;
042    private final Map<CachedPair<T2,T3>,MutableLong> bcCount;
043    private final Map<T1,MutableLong> aCount;
044    private final Map<T2,MutableLong> bCount;
045    private final Map<T3,MutableLong> cCount;
046
047    public TripleDistribution(long count, Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount, Map<CachedPair<T1,T2>,MutableLong> abCount, Map<CachedPair<T1,T3>,MutableLong> acCount, Map<CachedPair<T2,T3>,MutableLong> bcCount, Map<T1,MutableLong> aCount, Map<T2,MutableLong> bCount, Map<T3,MutableLong> cCount) {
048        this.count = count;
049        this.jointCount = jointCount;
050        this.abCount = abCount;
051        this.acCount = acCount;
052        this.bcCount = bcCount;
053        this.aCount = aCount;
054        this.bCount = bCount;
055        this.cCount = cCount;
056    }
057
058    public Map<CachedTriple<T1,T2,T3>,MutableLong> getJointCount() {
059        return jointCount;
060    }
061    
062    public Map<CachedPair<T1,T2>,MutableLong> getABCount() {
063        return abCount;
064    }
065    
066    public Map<CachedPair<T1,T3>,MutableLong> getACCount() {
067        return acCount;
068    }
069    
070    public Map<CachedPair<T2,T3>,MutableLong> getBCCount() {
071        return bcCount;
072    }
073    
074    public Map<T1,MutableLong> getACount() {
075        return aCount;
076    }
077    
078    public Map<T2,MutableLong> getBCount() {
079        return bCount;
080    }
081    
082    public Map<T3,MutableLong> getCCount() {
083        return cCount;
084    }
085    
086    public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third) {
087        Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
088        Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(DEFAULT_MAP_SIZE);
089        Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(DEFAULT_MAP_SIZE);
090        Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(DEFAULT_MAP_SIZE);
091        Map<T1,MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
092        Map<T2,MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
093        Map<T3,MutableLong> cCount = new HashMap<>(DEFAULT_MAP_SIZE);
094
095        long count = first.size();
096
097        if ((first.size() == second.size()) && (first.size() == third.size())) {
098            for (int i = 0; i < first.size(); i++) {
099                T1 a = first.get(i);
100                T2 b = second.get(i);
101                T3 c = third.get(i);
102                CachedTriple<T1,T2,T3> abc = new CachedTriple<>(a,b,c);
103                CachedPair<T1,T2> ab = abc.getAB();
104                CachedPair<T1,T3> ac = abc.getAC();
105                CachedPair<T2,T3> bc = abc.getBC();
106
107                MutableLong abcCurCount = jointCount.computeIfAbsent(abc, k -> new MutableLong());
108                abcCurCount.increment();
109
110                MutableLong abCurCount = abCount.computeIfAbsent(ab, k -> new MutableLong());
111                abCurCount.increment();
112                MutableLong acCurCount = acCount.computeIfAbsent(ac, k -> new MutableLong());
113                acCurCount.increment();
114                MutableLong bcCurCount = bcCount.computeIfAbsent(bc, k -> new MutableLong());
115                bcCurCount.increment();
116
117                MutableLong aCurCount = aCount.computeIfAbsent(a, k -> new MutableLong());
118                aCurCount.increment();
119                MutableLong bCurCount = bCount.computeIfAbsent(b, k -> new MutableLong());
120                bCurCount.increment();
121                MutableLong cCurCount = cCount.computeIfAbsent(c, k -> new MutableLong());
122                cCurCount.increment();
123            }
124
125            return new TripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
126        } else {
127            throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size());
128        }
129    }
130
131    public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount) {
132        Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(DEFAULT_MAP_SIZE);
133        Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(DEFAULT_MAP_SIZE);
134        Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(DEFAULT_MAP_SIZE);
135        Map<T1,MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
136        Map<T2,MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
137        Map<T3,MutableLong> cCount = new HashMap<>(DEFAULT_MAP_SIZE);
138
139        return constructFromMap(jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
140    }
141
142    public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount,
143                                                                           int abSize, int acSize, int bcSize,
144                                                                           int aSize, int bSize, int cSize) {
145        Map<CachedPair<T1,T2>,MutableLong> abCount = new HashMap<>(abSize);
146        Map<CachedPair<T1,T3>,MutableLong> acCount = new HashMap<>(acSize);
147        Map<CachedPair<T2,T3>,MutableLong> bcCount = new HashMap<>(bcSize);
148        Map<T1,MutableLong> aCount = new HashMap<>(aSize);
149        Map<T2,MutableLong> bCount = new HashMap<>(bSize);
150        Map<T3,MutableLong> cCount = new HashMap<>(cSize);
151
152        return constructFromMap(jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
153    }
154
155    public static <T1,T2,T3> TripleDistribution<T1,T2,T3> constructFromMap(Map<CachedTriple<T1,T2,T3>,MutableLong> jointCount,
156                                                                           Map<CachedPair<T1,T2>,MutableLong> abCount,
157                                                                           Map<CachedPair<T1,T3>,MutableLong> acCount,
158                                                                           Map<CachedPair<T2,T3>,MutableLong> bcCount,
159                                                                           Map<T1,MutableLong> aCount,
160                                                                           Map<T2,MutableLong> bCount,
161                                                                           Map<T3,MutableLong> cCount) {
162        long count = 0L;
163
164        for (Entry<CachedTriple<T1,T2,T3>,MutableLong> e : jointCount.entrySet()) {
165            CachedTriple<T1,T2,T3> abc = e.getKey();
166            long curCount = e.getValue().longValue();
167            CachedPair<T1,T2> ab = abc.getAB();
168            CachedPair<T1,T3> ac = abc.getAC();
169            CachedPair<T2,T3> bc = abc.getBC();
170            T1 a = abc.getA();
171            T2 b = abc.getB();
172            T3 c = abc.getC();
173
174            count += curCount;
175
176            MutableLong abCurCount = abCount.computeIfAbsent(ab, k -> new MutableLong());
177            abCurCount.increment(curCount);
178            MutableLong acCurCount = acCount.computeIfAbsent(ac, k -> new MutableLong());
179            acCurCount.increment(curCount);
180            MutableLong bcCurCount = bcCount.computeIfAbsent(bc, k -> new MutableLong());
181            bcCurCount.increment(curCount);
182
183            MutableLong aCurCount = aCount.computeIfAbsent(a, k -> new MutableLong());
184            aCurCount.increment(curCount);
185            MutableLong bCurCount = bCount.computeIfAbsent(b, k -> new MutableLong());
186            bCurCount.increment(curCount);
187            MutableLong cCurCount = cCount.computeIfAbsent(c, k -> new MutableLong());
188            cCurCount.increment(curCount);
189        }
190
191        return new TripleDistribution<>(count,jointCount,abCount,acCount,bcCount,aCount,bCount,cCount);
192    }
193    
194}