001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory.impl;
018
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.util.infotheory.InformationTheory;

import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
027
028/**
029 * A count distribution over {@link CachedPair} objects.
030 * @param <T1> The type of the first element
031 * @param <T2> The type of the second element
032 */
033public class PairDistribution<T1,T2> {
034
035    public final long count;
036
037    public final Map<CachedPair<T1,T2>,MutableLong> jointCounts;
038    public final Map<T1,MutableLong> firstCount;
039    public final Map<T2,MutableLong> secondCount;
040
041    public PairDistribution(long count, Map<CachedPair<T1,T2>,MutableLong> jointCounts, Map<T1,MutableLong> firstCount, Map<T2,MutableLong> secondCount) {
042        this.count = count;
043        this.jointCounts = new LinkedHashMap<>(jointCounts);
044        this.firstCount = new LinkedHashMap<>(firstCount);
045        this.secondCount = new LinkedHashMap<>(secondCount);
046    }
047
048    public PairDistribution(long count, LinkedHashMap<CachedPair<T1,T2>,MutableLong> jointCounts, LinkedHashMap<T1,MutableLong> firstCount, LinkedHashMap<T2,MutableLong> secondCount) {
049        this.count = count;
050        this.jointCounts = jointCounts;
051        this.firstCount = firstCount;
052        this.secondCount = secondCount;
053    }
054    
055    /**
056     * Generates the counts for two vectors. Returns a PairDistribution containing the joint
057     * count, and the two marginal counts.
058     * @param <T1> Type of the first array.
059     * @param <T2> Type of the second array.
060     * @param first An array of values.
061     * @param second Another array of values.
062     * @return The joint counts and the two marginal counts.
063     */
064    public static <T1,T2> PairDistribution<T1,T2> constructFromLists(List<T1> first, List<T2> second) {
065        LinkedHashMap<CachedPair<T1,T2>,MutableLong> abCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
066        LinkedHashMap<T1,MutableLong> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
067        LinkedHashMap<T2,MutableLong> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
068
069        if (first.size() == second.size()) {
070            long count = 0;
071            for (int i = 0; i < first.size(); i++) {
072                T1 a = first.get(i);
073                T2 b = second.get(i);
074                CachedPair<T1,T2> pair = new CachedPair<>(a,b);
075
076                MutableLong abCount = abCountDist.computeIfAbsent(pair, k -> new MutableLong());
077                abCount.increment();
078
079                MutableLong aCount = aCountDist.computeIfAbsent(a, k -> new MutableLong());
080                aCount.increment();
081
082                MutableLong bCount = bCountDist.computeIfAbsent(b, k -> new MutableLong());
083                bCount.increment();
084
085                count++;
086            }
087
088            return new PairDistribution<>(count,abCountDist,aCountDist,bCountDist);
089        } else {
090            throw new IllegalArgumentException("Counting requires arrays of the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
091        }
092    }
093
094    public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount) {
095        Map<T1,MutableLong> aCount = new HashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
096        Map<T2,MutableLong> bCount = new HashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
097
098        return constructFromMap(jointCount,aCount,bCount);
099    }
100
101    public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount, int aSize, int bSize) {
102        Map<T1,MutableLong> aCount = new HashMap<>(aSize);
103        Map<T2,MutableLong> bCount = new HashMap<>(bSize);
104
105        return constructFromMap(jointCount,aCount,bCount);
106    }
107
108    public static <T1,T2> PairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,MutableLong> jointCount,
109                                                                           Map<T1,MutableLong> aCount,
110                                                                           Map<T2,MutableLong> bCount) {
111        long count = 0L;
112        
113        for (Entry<CachedPair<T1,T2>,MutableLong> e : jointCount.entrySet()) {
114            CachedPair<T1,T2> pair = e.getKey();
115            long curCount = e.getValue().longValue();
116            T1 a = pair.getA();
117            T2 b = pair.getB();
118
119            MutableLong curACount = aCount.computeIfAbsent(a, k -> new MutableLong());
120            curACount.increment(curCount);
121
122            MutableLong curBCount = bCount.computeIfAbsent(b, k -> new MutableLong());
123            curBCount.increment(curCount);
124            count += curCount;
125        }
126
127        return new PairDistribution<>(count,jointCount,aCount,bCount);
128    }
129
130}