001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.clustering.evaluation;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import org.tribuo.clustering.ClusterID;
021import org.tribuo.evaluation.metrics.MetricTarget;
022import org.tribuo.util.infotheory.InformationTheory;
023import org.tribuo.util.infotheory.impl.PairDistribution;
024import org.apache.commons.math3.special.Gamma;
025
026import java.util.List;
027import java.util.Map;
028import java.util.function.BiFunction;
029
030/**
031 * Default metrics for evaluating clusterings.
032 */
033public enum ClusteringMetrics {
034
035    /**
036     * The normalized mutual information between the two clusterings
037     */
038    NORMALIZED_MI((target, context) -> ClusteringMetrics.normalizedMI(context)),
039    /**
040     * The normalized mutual information adjusted for chance.
041     */
042    ADJUSTED_MI((target, context) -> ClusteringMetrics.adjustedMI(context));
043
044    private final BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> impl;
045
046    ClusteringMetrics(BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> impl) {
047        this.impl = impl;
048    }
049
050    public BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> getImpl() {
051        return impl;
052    }
053
054    public ClusteringMetric forTarget(MetricTarget<ClusterID> tgt) {
055        return new ClusteringMetric(tgt, this.name(), this.getImpl());
056    }
057
058    /**
059     * Calculates the adjusted normalized mutual information between two clusterings.
060     * @param context The context containing the predicted clustering and the ground truth.
061     * @return The adjusted normalized mutual information.
062     */
063    public static double adjustedMI(ClusteringMetric.Context context) {
064        double mi = InformationTheory.mi(context.getPredictedIDs(), context.getTrueIDs());
065        double predEntropy = InformationTheory.entropy(context.getPredictedIDs());
066        double trueEntropy = InformationTheory.entropy(context.getTrueIDs());
067        double expectedMI = expectedMI(context.getPredictedIDs(), context.getTrueIDs());
068
069        double minEntropy = Math.min(predEntropy, trueEntropy);
070
071        return (mi - expectedMI) / (minEntropy - expectedMI);
072    }
073
074    /**
075     * Calculates the normalized mutual information between two clusterings.
076     * @param context The context containing the predicted clustering and the ground truth.
077     * @return The normalized mutual information.
078     */
079    public static double normalizedMI(ClusteringMetric.Context context) {
080        double mi = InformationTheory.mi(context.getPredictedIDs(), context.getTrueIDs());
081        double predEntropy = InformationTheory.entropy(context.getPredictedIDs());
082        double trueEntropy = InformationTheory.entropy(context.getTrueIDs());
083
084        return predEntropy < trueEntropy ? mi / predEntropy : mi / trueEntropy;
085    }
086
087    private static double expectedMI(List<Integer> first, List<Integer> second) {
088        PairDistribution<Integer,Integer> pd = PairDistribution.constructFromLists(first,second);
089
090        Map<Integer, MutableLong> firstCount = pd.firstCount;
091        Map<Integer,MutableLong> secondCount = pd.secondCount;
092        long count = pd.count;
093
094        double output = 0.0;
095
096        for (Map.Entry<Integer,MutableLong> f : firstCount.entrySet()) {
097            for (Map.Entry<Integer,MutableLong> s : secondCount.entrySet()) {
098                long fVal = f.getValue().longValue();
099                long sVal = s.getValue().longValue();
100                long minCount = Math.min(fVal, sVal);
101
102                long threshold = fVal + sVal - count;
103                long start = threshold > 1 ? threshold : 1;
104
105                for (long nij = start; nij < minCount; nij++) {
106                    double acc = ((double) nij) / count;
107                    acc *= Math.log(((double) (count * nij)) / (fVal * sVal));
108                    //numerator
109                    double logSpace = Gamma.logGamma(fVal + 1);
110                    logSpace += Gamma.logGamma(sVal + 1);
111                    logSpace += Gamma.logGamma(count - fVal + 1);
112                    logSpace += Gamma.logGamma(count - sVal + 1);
113                    //denominator
114                    logSpace -= Gamma.logGamma(count + 1);
115                    logSpace -= Gamma.logGamma(nij + 1);
116                    logSpace -= Gamma.logGamma(fVal - nij + 1);
117                    logSpace -= Gamma.logGamma(sVal - nij + 1);
118                    logSpace -= Gamma.logGamma(count - fVal - sVal + nij + 1);
119                    acc *= Math.exp(logSpace);
120                    output += acc;
121                }
122            }
123        }
124        return output;
125    }
126
127}