001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.evaluation;
018
019import org.tribuo.ImmutableOutputInfo;
020import org.tribuo.Model;
021import org.tribuo.Prediction;
022import org.tribuo.classification.Label;
023import org.tribuo.math.la.DenseMatrix;
024
025import java.util.ArrayList;
026import java.util.HashMap;
027import java.util.HashSet;
028import java.util.LinkedHashSet;
029import java.util.List;
030import java.util.Map;
031import java.util.Set;
032import java.util.function.ToDoubleFunction;
033import java.util.logging.Logger;
034
035/**
036 * A confusion matrix for {@link Label}s.
037 * <p>
038 * We interpret it as follows:
039 *
040 * {@code
041 * C[i, j] = k
042 * }
043 *
044 * means "the TRUE class 'j' was PREDICTED to be class 'i' a total of 'k' times".
045 *
046 * <p>
047 * In other words, the row indices correspond to the model's predictions, and the column indices correspond to
048 * the ground truth.
049 * </p>
050 */
051public final class LabelConfusionMatrix implements ConfusionMatrix<Label> {
052
053    private static final Logger logger = Logger.getLogger(LabelConfusionMatrix.class.getName());
054
055    private final ImmutableOutputInfo<Label> domain;
056
057    private final int total;
058    private final Map<Label, Double> occurrences;
059
060    private final Set<Label> observed;
061
062    private final DenseMatrix cm;
063
064    private List<Label> labelOrder;
065
066    /**
067     * Creates a confusion matrix from the supplied predictions, using the label info
068     * from the supplied model.
069     * @param model The model to use for the label information.
070     * @param predictions The predictions.
071     */
072    public LabelConfusionMatrix(Model<Label> model, List<Prediction<Label>> predictions) {
073        this(model.getOutputIDInfo(), predictions);
074    }
075
076    /**
077     * Creates a confusion matrix from the supplied predictions and label info.
078     * @throws IllegalArgumentException If the domain doesn't contain all the predictions.
079     * @param domain The label information.
080     * @param predictions The predictions.
081     */
082    public LabelConfusionMatrix(ImmutableOutputInfo<Label> domain, List<Prediction<Label>> predictions) {
083        this.domain = domain;
084        this.total = predictions.size();
085        this.cm = new DenseMatrix(domain.size(), domain.size());
086        this.occurrences = new HashMap<>();
087        this.observed = new HashSet<>();
088        tabulate(predictions);
089    }
090
091    /**
092     * Aggregate the predictions into this confusion matrix.
093     * @param predictions The predictions to aggregate.
094     */
095    private void tabulate(List<Prediction<Label>> predictions) {
096        predictions.forEach(prediction -> {
097            Label y = prediction.getExample().getOutput();
098            Label p = prediction.getOutput();
099            //
100            // Check that the ground truth label is valid
101            if (y.getLabel().equals(Label.UNKNOWN)) {
102                throw new IllegalArgumentException("Prediction with unknown ground truth. Unable to evaluate.");
103            }
104            occurrences.merge(y,1d, Double::sum);
105            observed.add(y);
106            observed.add(p);
107            int iy = getIDOrThrow(y);
108            int ip = getIDOrThrow(p);
109            cm.add(ip, iy, 1d);
110        });
111    }
112
113    @Override
114    public ImmutableOutputInfo<Label> getDomain() {
115        return domain;
116    }
117
118    @Override
119    public double support() {
120        return total;
121    }
122
123    @Override
124    public double support(Label label) {
125        return occurrences.getOrDefault(label, 0d);
126    }
127
128    @Override
129    public double tp(Label cls) {
130        return compute(cls, (i) -> cm.get(i, i));
131    }
132
133    @Override
134    public double fp(Label cls) {
135        // Row-wise sum less true positives
136        return compute(cls, i -> cm.rowSum(i) - cm.get(i, i));
137    }
138
139    @Override
140    public double fn(Label cls) {
141        // Column-wise sum less true positives
142        return compute(cls, i -> cm.columnSum(i) - cm.get(i, i));
143    }
144
145    @Override
146    public double tn(Label cls) {
147        int n = getDomain().size();
148        int i = getDomain().getID(cls);
149        double total = 0d;
150        for (int j = 0; j < n; j++) {
151            if (j == i) {
152                continue;
153            }
154            for (int k = 0; k < n; k++) {
155                if (k == i) {
156                    continue;
157                }
158                total += cm.get(j, k);
159            }
160        }
161        return total;
162    }
163
164    @Override
165    public double confusion(Label predicted, Label trueClass) {
166        int i = getDomain().getID(predicted);
167        int j = getDomain().getID(trueClass);
168        return cm.get(i, j);
169    }
170
171    /**
172     * A convenience method for extracting the appropriate label statistic.
173     * @param cls The label to check.
174     * @param getter The get function which accepts a label id.
175     * @return The statistic for that label id.
176     */
177    private double compute(Label cls, ToDoubleFunction<Integer> getter) {
178        int i = getDomain().getID(cls);
179        if (i < 0) {
180            logger.fine("Unknown Label " + cls);
181            return 0d;
182        }
183        return getter.applyAsDouble(i);
184    }
185
186    /**
187     * Gets the id for the supplied label, or throws an {@link IllegalArgumentException} if it's
188     * an unknown label.
189     * @param key The label.
190     * @return The int id for that label.
191     */
192    private int getIDOrThrow(Label key) {
193        int id = domain.getID(key);
194        if (id < 0) {
195            throw new IllegalArgumentException("Unknown label: " + key);
196        }
197        return id;
198    }
199
200    /**
201     * Sets the label order used in {@link #toString}.
202     * @param labelOrder The label order to use.
203     */
204    public void setLabelOrder(List<Label> labelOrder) {
205        this.labelOrder = labelOrder;
206    }
207
208    @Override
209    public String toString() {
210        if (labelOrder == null) {
211            labelOrder = new ArrayList<>(domain.getDomain());
212        }
213        labelOrder.retainAll(observed);
214        
215        int maxLen = Integer.MIN_VALUE;
216        for (Label label : labelOrder) {
217            maxLen = Math.max(label.getLabel().length(), maxLen);
218            maxLen = Math.max(String.format(" %,d", (int)(double)occurrences.getOrDefault(label,0.0)).length(), maxLen);
219        }
220
221        StringBuilder sb = new StringBuilder();
222        String trueLabelFormat = String.format("%%-%ds", maxLen + 2);
223        String predictedLabelFormat = String.format("%%%ds", maxLen + 2);
224        String countFormat = String.format("%%,%dd", maxLen + 2);
225
226        //
227        // Empty spot in first row for labels on subsequent rows.
228        sb.append(String.format(trueLabelFormat, ""));
229
230        //
231        // Labels across the top for predicted.
232        for (Label predictedLabel : labelOrder) {
233            sb.append(String.format(predictedLabelFormat, predictedLabel.getLabel()));
234        }
235        sb.append('\n');
236
237        for (Label trueLabel : labelOrder) {
238            sb.append(String.format(trueLabelFormat, trueLabel.getLabel()));
239            for (Label predictedLabel : labelOrder) {
240                int confusion = (int) confusion(predictedLabel, trueLabel);
241                sb.append(String.format(countFormat, confusion));
242            }
243            sb.append('\n');
244        }
245        return sb.toString();
246    }
247
248    /**
249     * Emits a HTML table representation of the Confusion Matrix.
250     * @return The confusion matrix as a HTML table.
251     */
252    public String toHTML() {
253        if (labelOrder == null) {
254            labelOrder = new ArrayList<>(domain.getDomain());
255        }
256        Set<Label> labelsToPrint = new LinkedHashSet<>(labelOrder);
257        labelsToPrint.retainAll(observed);
258        StringBuilder sb = new StringBuilder();
259        sb.append("<table>\n");
260        sb.append(String.format("<tr><th>True Label</th><th style=\"text-align:center\" colspan=\"%d\">Predicted Labels</th></tr>%n", occurrences.size() + 1));
261        sb.append("<tr><th></th>");
262        for (Label predictedLabel : labelsToPrint) {
263            sb.append("<th style=\"text-align:right\">")
264                    .append(predictedLabel)
265                    .append("</th>");
266        }
267        sb.append("<th style=\"text-align:right\">Total</th>");
268        sb.append("</tr>\n");
269        for (Label trueLabel : labelsToPrint) {
270            sb.append("<tr><th>").append(trueLabel).append("</th>");
271            double count = occurrences.getOrDefault(trueLabel, 0d);
272            for (Label predictedLabel : labelsToPrint) {
273                double tlmc = confusion(predictedLabel,trueLabel);
274                double percent = (tlmc / count) * 100;
275                sb.append("<td style=\"text-align:right\">")
276                        .append(String.format("%,d (%.1f%%)", (int)tlmc, percent))
277                        .append("</td>");
278            }
279            sb.append("<td style=\"text-align:right\">").append(count).append("</td>");
280            sb.append("</tr>\n");
281        }
282        sb.append("</table>");
283        return sb.toString();
284    }
285}