/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.util.infotheory.impl;

import org.tribuo.util.infotheory.InformationTheory;
import org.tribuo.util.infotheory.WeightedInformationTheory;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
 * Generates the counts for a pair of vectors. Contains the joint
 * count and the two marginal counts.
 * @param <T1> Type of the first list.
 * @param <T2> Type of the second list.
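 * <p>
 * A minimal usage sketch (the input lists below are illustrative placeholders, not part of this class):
 * <pre>{@code
 * List<String> first   = Arrays.asList("a", "a", "b");
 * List<String> second  = Arrays.asList("x", "y", "x");
 * List<Double> weights = Arrays.asList(1.0, 3.0, 2.0);
 * WeightedPairDistribution<String,String> dist =
 *         WeightedPairDistribution.constructFromLists(first, second, weights);
 * WeightCountTuple joint = dist.getJointCounts().get(new CachedPair<>("a", "x"));
 * // dist.count == 3, joint.count == 1
 * }</pre>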
 */
public class WeightedPairDistribution<T1,T2> {

    /**
     * The number of samples this distribution was constructed from.
     */
    public final long count;

    private final Map<CachedPair<T1,T2>,WeightCountTuple> jointCounts;
    private final Map<T1,WeightCountTuple> firstCount;
    private final Map<T2,WeightCountTuple> secondCount;

    /**
     * Constructs a weighted pair distribution, copying the supplied joint and marginal count maps.
     * @param count The total number of samples.
     * @param jointCounts The joint counts.
     * @param firstCount The first marginal counts.
     * @param secondCount The second marginal counts.
     */
    public WeightedPairDistribution(long count, Map<CachedPair<T1,T2>,WeightCountTuple> jointCounts, Map<T1,WeightCountTuple> firstCount, Map<T2,WeightCountTuple> secondCount) {
        this.count = count;
        this.jointCounts = new LinkedHashMap<>(jointCounts);
        this.firstCount = new LinkedHashMap<>(firstCount);
        this.secondCount = new LinkedHashMap<>(secondCount);
    }

    /**
     * Constructs a weighted pair distribution which takes ownership of the supplied maps without copying them.
     * @param count The total number of samples.
     * @param jointCounts The joint counts.
     * @param firstCount The first marginal counts.
     * @param secondCount The second marginal counts.
     */
    public WeightedPairDistribution(long count, LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> jointCounts, LinkedHashMap<T1,WeightCountTuple> firstCount, LinkedHashMap<T2,WeightCountTuple> secondCount) {
        this.count = count;
        this.jointCounts = jointCounts;
        this.firstCount = firstCount;
        this.secondCount = secondCount;
    }

    /**
     * Gets the joint count map.
     * @return The joint counts.
     */
    public Map<CachedPair<T1,T2>,WeightCountTuple> getJointCounts() {
        return jointCounts;
    }

    /**
     * Gets the first marginal count map.
     * @return The first marginal counts.
     */
    public Map<T1,WeightCountTuple> getFirstCount() {
        return firstCount;
    }

    /**
     * Gets the second marginal count map.
     * @return The second marginal counts.
     */
    public Map<T2,WeightCountTuple> getSecondCount() {
        return secondCount;
    }

    /**
     * Generates the counts for two vectors. Returns a WeightedPairDistribution containing the
     * joint counts and the two marginal counts, with the accumulated weights normalised.
     * @param <T1> Type of the first list.
     * @param <T2> Type of the second list.
     * @param first A list of values.
     * @param second Another list of values.
     * @param weights A list of per-example weights.
     * @return A WeightedPairDistribution.
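     * <p>
     * A small illustrative sketch (the lists are hypothetical; all three must share the same length
     * or an {@link IllegalArgumentException} is thrown):
     * <pre>{@code
     * List<Integer> xs = Arrays.asList(1, 1, 2, 2);
     * List<Integer> ys = Arrays.asList(0, 1, 0, 1);
     * List<Double>  ws = Arrays.asList(0.5, 0.5, 1.0, 1.0);
     * WeightedPairDistribution<Integer,Integer> dist =
     *         WeightedPairDistribution.constructFromLists(xs, ys, ws);
     * long firstMarginal = dist.getFirstCount().get(1).count; // == 2, elements 0 and 1 both have first value 1
     * }</pre>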
     */
    public static <T1,T2> WeightedPairDistribution<T1,T2> constructFromLists(List<T1> first, List<T2> second, List<Double> weights) {
        LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> countDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
        LinkedHashMap<T1,WeightCountTuple> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
        LinkedHashMap<T2,WeightCountTuple> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);

        if ((first.size() == second.size()) && (first.size() == weights.size())) {
            long count = 0;
            for (int i = 0; i < first.size(); i++) {
                T1 a = first.get(i);
                T2 b = second.get(i);
                double weight = weights.get(i);
                CachedPair<T1,T2> pair = new CachedPair<>(a,b);

                // Accumulate the joint weight and count for this (a,b) pair.
                WeightCountTuple abCurCount = countDist.computeIfAbsent(pair,(k) -> new WeightCountTuple());
                abCurCount.weight += weight;
                abCurCount.count++;

                // Accumulate the marginal weight and count for the first element.
                WeightCountTuple aCurCount = aCountDist.computeIfAbsent(a,(k) -> new WeightCountTuple());
                aCurCount.weight += weight;
                aCurCount.count++;

                // Accumulate the marginal weight and count for the second element.
                WeightCountTuple bCurCount = bCountDist.computeIfAbsent(b,(k) -> new WeightCountTuple());
                bCurCount.weight += weight;
                bCurCount.count++;

                count++;
            }

            // Normalise each accumulated weight by its count.
            WeightedInformationTheory.normaliseWeights(countDist);
            WeightedInformationTheory.normaliseWeights(aCountDist);
            WeightedInformationTheory.normaliseWeights(bCountDist);

            return new WeightedPairDistribution<>(count,countDist,aCountDist,bCountDist);
        } else {
            throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
        }
    }

    /**
     * Generates a WeightedPairDistribution from a joint count map by computing the marginal
     * distributions for the first and second elements.
     * This assumes the weights have already been normalised.
     * @param <T1> Type of the first element.
     * @param <T2> Type of the second element.
     * @param jointCount The (normalised) input map.
     * @return A WeightedPairDistribution.
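     * <p>
     * A minimal sketch of building the input map by hand (the keys and values here are hypothetical,
     * and the weights are assumed to already be normalised per-occurrence averages):
     * <pre>{@code
     * Map<CachedPair<String,String>,WeightCountTuple> joint = new LinkedHashMap<>();
     * WeightCountTuple ax = new WeightCountTuple();
     * ax.count = 2;
     * ax.weight = 1.5; // average weight of the two ("a","x") occurrences
     * joint.put(new CachedPair<>("a", "x"), ax);
     * WeightCountTuple ay = new WeightCountTuple();
     * ay.count = 1;
     * ay.weight = 0.5;
     * joint.put(new CachedPair<>("a", "y"), ay);
     * WeightedPairDistribution<String,String> dist = WeightedPairDistribution.constructFromMap(joint);
     * // dist.count == 3, and the "a" marginal aggregates both entries (count == 3)
     * }</pre>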
     */
    public static <T1,T2> WeightedPairDistribution<T1,T2> constructFromMap(Map<CachedPair<T1,T2>,WeightCountTuple> jointCount) {
        LinkedHashMap<CachedPair<T1,T2>,WeightCountTuple> countDist = new LinkedHashMap<>(jointCount);
        LinkedHashMap<T1,WeightCountTuple> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
        LinkedHashMap<T2,WeightCountTuple> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);

        long count = 0L;

        for (Entry<CachedPair<T1,T2>,WeightCountTuple> e : countDist.entrySet()) {
            CachedPair<T1,T2> pair = e.getKey();
            WeightCountTuple tuple = e.getValue();
            T1 a = pair.getA();
            T2 b = pair.getB();
            // The input weights are normalised per-occurrence, so recover the total weight for this pair.
            double weight = tuple.weight * tuple.count;

            // Accumulate the marginal weight and count for the first element.
            WeightCountTuple aCurCount = aCountDist.computeIfAbsent(a,(k) -> new WeightCountTuple());
            aCurCount.weight += weight;
            aCurCount.count += tuple.count;

            // Accumulate the marginal weight and count for the second element.
            WeightCountTuple bCurCount = bCountDist.computeIfAbsent(b,(k) -> new WeightCountTuple());
            bCurCount.weight += weight;
            bCurCount.count += tuple.count;

            count += tuple.count;
        }

        // Normalise each accumulated marginal weight by its count.
        WeightedInformationTheory.normaliseWeights(aCountDist);
        WeightedInformationTheory.normaliseWeights(bCountDist);

        return new WeightedPairDistribution<>(count,countDist,aCountDist,bCountDist);
    }

}