001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel.evaluation;
018
019import org.tribuo.ImmutableOutputInfo;
020import org.tribuo.Model;
021import org.tribuo.Prediction;
022import org.tribuo.classification.Label;
023import org.tribuo.classification.evaluation.ConfusionMatrix;
024import org.tribuo.math.la.DenseMatrix;
025import org.tribuo.multilabel.MultiLabel;
026import org.tribuo.multilabel.MultiLabelFactory;
027
028import java.util.List;
029import java.util.Set;
030import java.util.function.Function;
031
032/**
033 * A {@link ConfusionMatrix} which accepts {@link MultiLabel}s.
034 *
035 * <p>
036 * In a multi-label confusion matrix M,
037 * <pre>
038 * tn = M[:, 0, 0]
039 * fn = M[:, 1, 0]
040 * tp = M[:, 1, 1]
041 * fp = M[:, 0, 1]
042 * </pre>
043 * <p>
044 * For class-wise values,
045 * <pre>
046 * tn(class i) = M[i, 0, 0]
047 * fn(class i) = M[i, 1, 0]
048 * tp(class i) = M[i, 1, 1]
049 * fp(class i) = M[i, 0, 1]
050 * </pre>
051 */
052public final class MultiLabelConfusionMatrix implements ConfusionMatrix<MultiLabel> {
053
054    private final ImmutableOutputInfo<MultiLabel> domain;
055    private final DenseMatrix[] mcm;
056    private final DenseMatrix confusion;
057
058    public MultiLabelConfusionMatrix(Model<MultiLabel> model, List<Prediction<MultiLabel>> predictions) {
059        this(model.getOutputIDInfo(), predictions);
060    }
061
062    MultiLabelConfusionMatrix(ImmutableOutputInfo<MultiLabel> domain, List<Prediction<MultiLabel>> predictions) {
063        this.domain = domain;
064        ConfusionMatrixTuple tab = tabulate(domain, predictions);
065        this.mcm = tab.mcm;
066        this.confusion = tab.confusion;
067    }
068
069    @Override
070    public double support(MultiLabel cls) {
071        double total = 0d;
072        for (Label label : cls.getLabelSet()) {
073            int ix = getDomain().getID(new MultiLabel(label));
074            /*
075            mcm[i] =
076            [tn, fn]
077            [fp, tp]
078
079            support = false negatives + true positives
080
081            false neg => ground truth was [label] but we predicted something else
082            true pos  => ground truth was [label] and we predicted [label]
083
084            (whereas: false pos => ground truth was NOT [label] but we predicted [label])
085
086            so
087
088            support = false neg + true pos = mcm[i, 0, 1] + mcm[i, 1, 1] = mcm[i, :, 1].sum()
089             */
090            total += mcm[ix].getColumn(1).sum();
091        }
092        return total;
093    }
094
095    @Override
096    public ImmutableOutputInfo<MultiLabel> getDomain() {
097        return domain;
098    }
099
100    @Override
101    public double support() {
102        double total = 0d;
103        for (int i = 0; i < domain.size(); i++) {
104            total += mcm[i].getColumn(1).sum();
105        }
106        return total;
107    }
108
109    @Override
110    public double tp(MultiLabel cls) {
111        return compute(cls, (cm) -> cm.get(1, 1));
112    }
113
114    @Override
115    public double fp(MultiLabel cls) {
116        return compute(cls, (cm) -> cm.get(0, 1));
117    }
118
119    @Override
120    public double fn(MultiLabel cls) {
121        return compute(cls, (cm) -> cm.get(1, 0));
122    }
123
124    @Override
125    public double tn(MultiLabel cls) {
126        return compute(cls, (cm) -> cm.get(0, 0));
127    }
128
129    private double compute(MultiLabel cls, Function<DenseMatrix, Double> getter) {
130        double total = 0d;
131        for (Label label : cls.getLabelSet()) {
132            int i = domain.getID(new MultiLabel(label.getLabel()));
133            //
134            // When input class is not in the domain, ID will be -1.
135            if (i < 0) {
136                continue;
137            }
138            DenseMatrix cm = mcm[i];
139            total += getter.apply(cm);
140        }
141        return total;
142    }
143
144    @Override
145    public double confusion(MultiLabel predicted, MultiLabel truth) {
146        double total = 0d;
147        Set<Label> trueSet = truth.getLabelSet();
148        Set<Label> predSet = predicted.getLabelSet();
149        for (Label predLabel : predSet) {
150            int idx = domain.getID(new MultiLabel(predLabel.getLabel()));
151            for (Label trueLabel : trueSet) {
152                int jdx = domain.getID(new MultiLabel(trueLabel.getLabel()));
153                total += this.confusion.get(idx, jdx);
154            }
155        }
156        return total;
157    }
158
159    @Override
160    public String toString() {
161        StringBuilder sb = new StringBuilder();
162        sb.append("[");
163        for (int i = 0; i < mcm.length; i++) {
164            DenseMatrix cm = mcm[i];
165            sb.append(cm.toString());
166            sb.append("\n");
167        }
168        sb.append("]");
169        return sb.toString();
170    }
171
172    static ConfusionMatrixTuple tabulate(ImmutableOutputInfo<MultiLabel> domain, List<Prediction<MultiLabel>> predictions) {
173        // this just keeps track of how many times [class x] was predicted to be [class y]
174        DenseMatrix confusion = new DenseMatrix(domain.size(), domain.size());
175
176        DenseMatrix[] mcm = new DenseMatrix[domain.size()];
177        for (int i = 0; i < domain.size(); i++) {
178            mcm[i] = new DenseMatrix(2, 2);
179        }
180
181        int predIndex = 0;
182        for (Prediction<MultiLabel> prediction : predictions) {
183            MultiLabel predictedOutput = prediction.getOutput();
184            MultiLabel trueOutput = prediction.getExample().getOutput();
185            if (trueOutput.equals(MultiLabelFactory.UNKNOWN_MULTILABEL)) {
186                throw new IllegalArgumentException("The sentinel Unknown MultiLabel was used as a ground truth label at prediction number " + predIndex);
187            } else if (predictedOutput.equals(MultiLabelFactory.UNKNOWN_MULTILABEL)) {
188                throw new IllegalArgumentException("The sentinel Unknown MultiLabel was predicted by the model at prediction number " + predIndex);
189            }
190
191            Set<Label> trueSet = trueOutput.getLabelSet();
192            Set<Label> predSet = predictedOutput.getLabelSet();
193
194            //
195            // Count true positives and false positives
196            for (Label pred : predSet) {
197                int idx = domain.getID(new MultiLabel(pred.getLabel()));
198                if (trueSet.contains(pred)) {
199                    //
200                    // true positive: mcm[i, 1, 1]++
201                    mcm[idx].add(1, 1, 1d);
202                } else {
203                    //
204                    // false positive: mcm[i, 1, 0]++
205                    mcm[idx].add(1, 0, 1d);
206                }
207            }
208
209            //
210            // Count false negatives and populate the confusion table
211            for (Label trueLabel : trueSet) {
212                int idx = domain.getID(new MultiLabel(trueLabel.getLabel()));
213                if (idx < 0) {
214                    throw new IllegalArgumentException("Unknown label '" + trueLabel.getLabel() + "' found in the ground truth labels at prediction number " + predIndex
215                            + ", this label is not known by the model which made the predictions.");
216                }
217
218                //
219                // Doing two things in this loop:
220                // 1) Checking if predSet contains trueLabel
221                // 2) Counting the # of times [trueLabel] was predicted to be [predLabel] to populate the confusion table
222                boolean found = false;
223                for (Label predLabel : predSet) {
224                    int jdx = domain.getID(new MultiLabel(predLabel.getLabel()));
225                    confusion.add(jdx, idx, 1d);
226
227                    if (predLabel.equals(trueLabel)) {
228                        found = true;
229                    }
230                }
231
232                if (!found) {
233                    //
234                    // false negative: mcm[i, 0, 1]++
235                    mcm[idx].add(0, 1, 1d);
236                }
237                // else { true positive: already counted }
238            }
239
240            //
241            // True negatives everywhere else
242            for (MultiLabel multilabel : domain.getDomain()) {
243                Set<Label> labels = multilabel.getLabelSet();
244                for (Label label : labels) {
245                    if (!trueSet.contains(label) && !predSet.contains(label)) {
246                        int ix = domain.getID(new MultiLabel(label));
247                        mcm[ix].add(0, 0, 1d);
248                    }
249                }
250            }
251            predIndex++;
252        }
253
254        return new ConfusionMatrixTuple(mcm, confusion);
255    }
256
257    /**
258     * It's a record, ooops not yet, we don't require Java 14.
259     */
260    static final class ConfusionMatrixTuple {
261        final DenseMatrix[] mcm;
262        final DenseMatrix confusion;
263        ConfusionMatrixTuple(DenseMatrix[] mcm, DenseMatrix confusion) {
264            this.mcm = mcm;
265            this.confusion = confusion;
266        }
267
268        DenseMatrix[] getMCM() {
269            return mcm;
270        }
271    }
272}