001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel.evaluation;
018
019import org.tribuo.Prediction;
020import org.tribuo.classification.evaluation.ConfusionMatrix;
021import org.tribuo.evaluation.metrics.EvaluationMetric.Average;
022import org.tribuo.evaluation.metrics.MetricID;
023import org.tribuo.evaluation.metrics.MetricTarget;
024import org.tribuo.multilabel.MultiLabel;
025import org.tribuo.provenance.EvaluationProvenance;
026
027import java.util.ArrayList;
028import java.util.Collections;
029import java.util.List;
030import java.util.Map;
031
032
033/**
034 * The implementation of a {@link MultiLabelEvaluation} using the default metrics.
035 * <p>
036 * The classification metrics consider labels independently.
037 */
038public final class MultiLabelEvaluationImpl implements MultiLabelEvaluation {
039
040    private final Map<MetricID<MultiLabel>, Double> results;
041    private final MultiLabelMetric.Context context;
042    private final ConfusionMatrix<MultiLabel> cm;
043    private final EvaluationProvenance provenance;
044
045    /**
046     * Builds an evaluation using the supplied metric results, confusion matrix and evaluation provenance.
047     * @param results The results.
048     * @param context The context carrying the confusion matrix
049     * @param provenance The evaluation provenance.
050     */
051    MultiLabelEvaluationImpl(Map<MetricID<MultiLabel>, Double> results,
052                             MultiLabelMetric.Context context,
053                             EvaluationProvenance provenance) {
054        this.results = results;
055        this.context = context;
056        this.cm = context.getCM();
057        this.provenance = provenance;
058    }
059
060    @Override
061    public List<Prediction<MultiLabel>> getPredictions() {
062        return context.getPredictions();
063    }
064
065    @Override
066    public double balancedErrorRate() {
067        // Target doesn't matter for balanced error rate, so we just use Average.macro
068        // as it's the macro average of the recalls.
069        MetricTarget<MultiLabel> dummy = MetricTarget.macroAverageTarget();
070        return get(dummy, MultiLabelMetrics.BALANCED_ERROR_RATE);
071    }
072
073    @Override
074    public ConfusionMatrix<MultiLabel> getConfusionMatrix() {
075        return cm;
076    }
077
078    @Override
079    public double confusion(MultiLabel predicted, MultiLabel truth) {
080        return cm.confusion(predicted, truth);
081    }
082
083    @Override
084    public double tp(MultiLabel label) {
085        return get(label, MultiLabelMetrics.TP);
086    }
087
088    @Override
089    public double tp() {
090        return get(Average.MICRO, MultiLabelMetrics.TP);
091    }
092
093    @Override
094    public double macroTP() {
095        return get(Average.MACRO, MultiLabelMetrics.TP);
096    }
097
098    @Override
099    public double fp(MultiLabel label) {
100        return get(label, MultiLabelMetrics.FP);
101    }
102
103    @Override
104    public double fp() {
105        return get(Average.MICRO, MultiLabelMetrics.FP);
106    }
107
108    @Override
109    public double macroFP() {
110        return get(Average.MACRO, MultiLabelMetrics.FP);
111    }
112
113    @Override
114    public double tn(MultiLabel label) {
115        return get(label, MultiLabelMetrics.TN);
116    }
117
118    @Override
119    public double tn() {
120        return get(Average.MICRO, MultiLabelMetrics.TN);
121    }
122
123    @Override
124    public double macroTN() {
125        return get(Average.MACRO, MultiLabelMetrics.TN);
126    }
127
128    @Override
129    public double fn(MultiLabel label) {
130        return get(label, MultiLabelMetrics.FN);
131    }
132
133    @Override
134    public double fn() {
135        return get(Average.MICRO, MultiLabelMetrics.FN);
136    }
137
138    @Override
139    public double macroFN() {
140        return get(Average.MACRO, MultiLabelMetrics.FN);
141    }
142
143    @Override
144    public double precision(MultiLabel label) {
145        return get(new MetricTarget<>(label), MultiLabelMetrics.PRECISION);
146    }
147
148    @Override
149    public double microAveragedPrecision() {
150        return get(new MetricTarget<>(Average.MICRO), MultiLabelMetrics.PRECISION);
151    }
152
153    @Override
154    public double macroAveragedPrecision() {
155        return get(new MetricTarget<>(Average.MACRO), MultiLabelMetrics.PRECISION);
156    }
157
158    @Override
159    public double recall(MultiLabel label) {
160        return get(new MetricTarget<>(label), MultiLabelMetrics.RECALL);
161    }
162
163    @Override
164    public double microAveragedRecall() {
165        return get(new MetricTarget<>(Average.MICRO), MultiLabelMetrics.RECALL);
166    }
167
168    @Override
169    public double macroAveragedRecall() {
170        return get(new MetricTarget<>(Average.MACRO), MultiLabelMetrics.RECALL);
171    }
172
173    @Override
174    public double f1(MultiLabel label) {
175        return get(new MetricTarget<>(label), MultiLabelMetrics.F1);
176    }
177
178    @Override
179    public double microAveragedF1() {
180        return get(new MetricTarget<>(Average.MICRO), MultiLabelMetrics.F1);
181    }
182
183    @Override
184    public double macroAveragedF1() {
185        return get(new MetricTarget<>(Average.MACRO), MultiLabelMetrics.F1);
186    }
187
188    @Override
189    public Map<MetricID<MultiLabel>, Double> asMap() {
190        return Collections.unmodifiableMap(results);
191    }
192
193    @Override
194    public EvaluationProvenance getProvenance() {
195        return provenance;
196    }
197
198    @Override
199    public String toString() {
200        List<MultiLabel> labelOrder = new ArrayList<>(cm.getDomain().getDomain());
201        return toString(labelOrder);
202    }
203
204    private String toString(List<MultiLabel> labelOrder) {
205        StringBuilder sb = new StringBuilder();
206        int tp = 0;
207        int fn = 0;
208        int fp = 0;
209        int n = 0;
210        //
211        // Figure out the biggest class label and therefore the format string
212        // that we should use for them.
213        int maxLabelSize = "Balanced Error Rate".length();
214        for(MultiLabel label : labelOrder) {
215            maxLabelSize = Math.max(maxLabelSize, label.getLabelString().length());
216        }
217        String labelFormatString = String.format("%%-%ds", maxLabelSize+2);
218        sb.append(String.format(labelFormatString, "Class"));
219        sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp"));
220        sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1"));
221        for (MultiLabel label : labelOrder) {
222            if (cm.support(label) == 0) {
223                continue;
224            }
225            n += cm.support(label);
226            tp += cm.tp(label);
227            fn += cm.fn(label);
228            fp += cm.fp(label);
229            sb.append(String.format(labelFormatString, label));
230            sb.append(String.format("%,12d%,12d%,12d%,12d",
231                    (int) cm.support(label),
232                    (int) cm.tp(label),
233                    (int) cm.fn(label),
234                    (int) cm.fp(label)
235            ));
236            sb.append(String.format("%12.3f%12.3f%12.3f%n", recall(label), precision(label), f1(label)));
237        }
238        sb.append(String.format(labelFormatString, "Total"));
239        sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp));
240        sb.append(String.format(labelFormatString, "Accuracy"));
241        sb.append(String.format("%60.3f%n", ((double) tp) / n));
242        sb.append(String.format(labelFormatString, "Micro Average"));
243        sb.append(String.format("%60.3f%12.3f%12.3f%n", microAveragedRecall(), microAveragedPrecision(), microAveragedF1()));
244        sb.append(String.format(labelFormatString, "Macro Average"));
245        sb.append(String.format("%60.3f%12.3f%12.3f%n", macroAveragedRecall(), macroAveragedPrecision(), macroAveragedF1()));
246        sb.append(String.format(labelFormatString, "Balanced Error Rate"));
247        sb.append(String.format("%60.3f", balancedErrorRate()));
248        return sb.toString();
249    }
250
251    private double get(MetricTarget<MultiLabel> tgt, MultiLabelMetrics metric) {
252        return get(metric.forTarget(tgt).getID());
253    }
254
255    private double get(MultiLabel label, MultiLabelMetrics metric) {
256        return get(metric
257                .forTarget(new MetricTarget<>(label))
258                .getID());
259    }
260
261    private double get(Average avg, MultiLabelMetrics metric) {
262        return get(metric
263                .forTarget(new MetricTarget<>(avg))
264                .getID());
265    }
266
267}