/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.clustering.evaluation;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.clustering.ClusterID;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.util.infotheory.InformationTheory;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.apache.commons.math3.special.Gamma;

import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

/**
 * Default metrics for evaluating clusterings.
 */
public enum ClusteringMetrics {

    /**
     * The normalized mutual information between the two clusterings
     */
    NORMALIZED_MI((target, context) -> ClusteringMetrics.normalizedMI(context)),
    /**
     * The normalized mutual information adjusted for chance.
     */
    ADJUSTED_MI((target, context) -> ClusteringMetrics.adjustedMI(context));

    private final BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> impl;

    ClusteringMetrics(BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> impl) {
        this.impl = impl;
    }

    /**
     * Gets the function which computes this metric.
     * @return The implementation function, taking a metric target and a context and producing the metric value.
     */
    public BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> getImpl() {
        return impl;
    }

    /**
     * Constructs a {@link ClusteringMetric} for the supplied target, named after this enum
     * constant and backed by this constant's implementation function.
     * @param tgt The metric target.
     * @return A clustering metric bound to the target.
     */
    public ClusteringMetric forTarget(MetricTarget<ClusterID> tgt) {
        return new ClusteringMetric(tgt, this.name(), this.getImpl());
    }

    /**
     * Calculates the adjusted normalized mutual information between two clusterings.
     * <p>
     * Computed as {@code (mi - expectedMI) / (min(predEntropy, trueEntropy) - expectedMI)}.
     * No guard is applied to the denominator, so a zero denominator yields a
     * non-finite result (NaN or infinity).
     * @param context The context containing the predicted clustering and the ground truth.
     * @return The adjusted normalized mutual information.
     */
    public static double adjustedMI(ClusteringMetric.Context context) {
        double mi = InformationTheory.mi(context.getPredictedIDs(), context.getTrueIDs());
        double predEntropy = InformationTheory.entropy(context.getPredictedIDs());
        double trueEntropy = InformationTheory.entropy(context.getTrueIDs());
        double expectedMI = expectedMI(context.getPredictedIDs(), context.getTrueIDs());

        double minEntropy = Math.min(predEntropy, trueEntropy);

        return (mi - expectedMI) / (minEntropy - expectedMI);
    }

    /**
     * Calculates the normalized mutual information between two clusterings.
     * <p>
     * The mutual information is divided by the smaller of the two entropies.
     * No guard is applied to the denominator, so a zero entropy yields a
     * non-finite result (NaN or infinity).
     * @param context The context containing the predicted clustering and the ground truth.
     * @return The normalized mutual information.
     */
    public static double normalizedMI(ClusteringMetric.Context context) {
        double mi = InformationTheory.mi(context.getPredictedIDs(), context.getTrueIDs());
        double predEntropy = InformationTheory.entropy(context.getPredictedIDs());
        double trueEntropy = InformationTheory.entropy(context.getTrueIDs());

        return predEntropy < trueEntropy ? mi / predEntropy : mi / trueEntropy;
    }

    /**
     * Calculates the expected mutual information between two clusterings with the
     * observed cluster sizes, assuming the co-occurrence counts follow a
     * hypergeometric distribution.
     * <p>
     * For every pair of cluster sizes {@code (a, b)} out of {@code N} items, this sums
     * {@code (nij / N) * log(N * nij / (a * b)) * P(nij)} over every feasible overlap
     * {@code nij} from {@code max(1, a + b - N)} up to and including {@code min(a, b)}.
     * The probability {@code P(nij)} is evaluated in log space via
     * {@link Gamma#logGamma(double)} to avoid overflowing the factorials.
     * @param first The first list of cluster ids.
     * @param second The second list of cluster ids.
     * @return The expected mutual information.
     */
    private static double expectedMI(List<Integer> first, List<Integer> second) {
        PairDistribution<Integer,Integer> pd = PairDistribution.constructFromLists(first,second);

        // NOTE(review): firstCount/secondCount are presumed to be the per-id occurrence
        // counts of each list and count the total number of pairs - confirm against PairDistribution.
        Map<Integer,MutableLong> firstCount = pd.firstCount;
        Map<Integer,MutableLong> secondCount = pd.secondCount;
        long count = pd.count;

        double output = 0.0;

        for (Map.Entry<Integer,MutableLong> f : firstCount.entrySet()) {
            long fVal = f.getValue().longValue();
            for (Map.Entry<Integer,MutableLong> s : secondCount.entrySet()) {
                long sVal = s.getValue().longValue();
                long minCount = Math.min(fVal, sVal);

                // The overlap can't be smaller than fVal + sVal - count, and zero
                // overlap contributes nothing to the sum, so start from at least 1.
                long start = Math.max(1L, fVal + sVal - count);

                // Terms of the log probability which don't depend on nij, accumulated
                // in the same order as before so the floating point result is unchanged.
                //numerator
                double logInvariant = Gamma.logGamma(fVal + 1);
                logInvariant += Gamma.logGamma(sVal + 1);
                logInvariant += Gamma.logGamma(count - fVal + 1);
                logInvariant += Gamma.logGamma(count - sVal + 1);
                //denominator
                logInvariant -= Gamma.logGamma(count + 1);

                // The upper bound is inclusive: an overlap of exactly min(fVal, sVal)
                // is a feasible outcome and must be part of the expectation.
                for (long nij = start; nij <= minCount; nij++) {
                    double acc = ((double) nij) / count;
                    acc *= Math.log(((double) (count * nij)) / (fVal * sVal));
                    //denominator terms which depend on nij
                    double logSpace = logInvariant;
                    logSpace -= Gamma.logGamma(nij + 1);
                    logSpace -= Gamma.logGamma(fVal - nij + 1);
                    logSpace -= Gamma.logGamma(sVal - nij + 1);
                    logSpace -= Gamma.logGamma(count - fVal - sVal + nij + 1);
                    acc *= Math.exp(logSpace);
                    output += acc;
                }
            }
        }
        return output;
    }

}