001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sequence;
018
019import org.tribuo.Prediction;
020import org.tribuo.classification.Label;
021import org.tribuo.classification.evaluation.ConfusionMatrix;
022import org.tribuo.classification.evaluation.LabelMetric;
023import org.tribuo.classification.evaluation.LabelMetrics;
024import org.tribuo.evaluation.metrics.EvaluationMetric;
025import org.tribuo.evaluation.metrics.MetricID;
026import org.tribuo.evaluation.metrics.MetricTarget;
027import org.tribuo.provenance.EvaluationProvenance;
028import org.tribuo.sequence.SequenceEvaluation;
029
030import java.util.ArrayList;
031import java.util.Collections;
032import java.util.List;
033import java.util.Map;
034import java.util.logging.Logger;
035
036/**
037 * A class that can be used to evaluate a sequence label classification model element wise on a given set of data.
038 */
039public class LabelSequenceEvaluation implements SequenceEvaluation<Label> {
040
041    private static final Logger logger = Logger.getLogger(LabelSequenceEvaluation.class.getName());
042    
043    private final Map<MetricID<Label>, Double> results;
044    private final LabelMetric.Context ctx;
045    private final ConfusionMatrix<Label> cm;
046    private final EvaluationProvenance provenance;
047
048    protected LabelSequenceEvaluation(Map<MetricID<Label>, Double> results,
049                                      LabelMetric.Context ctx,
050                                      EvaluationProvenance provenance) {
051        this.results = results;
052        this.ctx = ctx;
053        this.cm = ctx.getCM();
054        this.provenance = provenance;
055    }
056
057    /**
058     * Gets the flattened predictions.
059     * @return The flattened predictions.
060     */
061    public List<Prediction<Label>> getPredictions() {
062        return ctx.getPredictions();
063    }
064
065    /**
066     * Gets the confusion matrix backing this evaluation.
067     * @return The confusion matrix.
068     */
069    public ConfusionMatrix<Label> getConfusionMatrix() {
070        return cm;
071    }
072
073    @Override
074    public Map<MetricID<Label>, Double> asMap() {
075        return Collections.unmodifiableMap(results);
076    }
077
078    /**
079     * Note: confusion is not stored in the underlying map, so it won't show up in aggregation.
080     * @param predictedLabel The predicted label.
081     * @param trueLabel The true label.
082     * @return The number of times that {@code predictedLabel} was predicted for <code>trueLabel</code>.
083     */
084    public double confusion(Label predictedLabel, Label trueLabel) {
085        return cm.confusion(predictedLabel, trueLabel);
086    }
087
088    public double tp(Label label) {
089        return get(label, LabelMetrics.TP);
090    }
091
092    public double tp() {
093        return get(EvaluationMetric.Average.MICRO, LabelMetrics.TP);
094    }
095
096    public double macroTP() {
097        return get(EvaluationMetric.Average.MACRO, LabelMetrics.TP);
098    }
099
100    public double fp(Label label) {
101        return get(label, LabelMetrics.FP);
102    }
103
104    public double fp() {
105        return get(EvaluationMetric.Average.MICRO, LabelMetrics.FP);
106    }
107
108    public double macroFP() {
109        return get(EvaluationMetric.Average.MACRO, LabelMetrics.FP);
110    }
111
112    public double tn(Label label) {
113        return get(label, LabelMetrics.TN);
114    }
115
116    public double tn() {
117        return get(EvaluationMetric.Average.MICRO, LabelMetrics.TN);
118    }
119
120    public double macroTN() {
121        return get(EvaluationMetric.Average.MACRO, LabelMetrics.TN);
122    }
123
124    public double fn(Label label) {
125        return get(label, LabelMetrics.FN);
126    }
127
128    public double fn() {
129        return get(EvaluationMetric.Average.MICRO, LabelMetrics.FN);
130    }
131
132    public double macroFN() {
133        return get(EvaluationMetric.Average.MACRO, LabelMetrics.FN);
134    }
135
136    public double precision(Label label) {
137        return get(label, LabelMetrics.PRECISION);
138    }
139
140    public double microAveragedPrecision() {
141        return get(EvaluationMetric.Average.MICRO, LabelMetrics.PRECISION);
142    }
143
144    public double macroAveragedPrecision() {
145        return get(EvaluationMetric.Average.MACRO, LabelMetrics.PRECISION);
146    }
147
148    public double recall(Label label) {
149        return get(label, LabelMetrics.RECALL);
150    }
151
152    public double microAveragedRecall() {
153        return get(EvaluationMetric.Average.MICRO, LabelMetrics.RECALL);
154    }
155
156    public double macroAveragedRecall() {
157        return get(EvaluationMetric.Average.MACRO, LabelMetrics.RECALL);
158    }
159
160    public double f1(Label label) {
161        return get(label, LabelMetrics.RECALL);
162    }
163
164    public double microAveragedF1() {
165        return get(EvaluationMetric.Average.MICRO, LabelMetrics.F1);
166    }
167
168    public double macroAveragedF1() {
169        return get(EvaluationMetric.Average.MACRO, LabelMetrics.F1);
170    }
171
172    public double accuracy() {
173        return get(EvaluationMetric.Average.MICRO, LabelMetrics.ACCURACY);
174    }
175
176    public double accuracy(Label label) {
177        return get(label, LabelMetrics.ACCURACY);
178    }
179
180    public double balancedErrorRate() {
181        // Target doesn't matter for balanced error rate, so we just use Average.macro
182        // as it's the macro averaged recall.
183        return get(EvaluationMetric.Average.MACRO, LabelMetrics.BALANCED_ERROR_RATE);
184    }
185
186    @Override
187    public EvaluationProvenance getProvenance() { return provenance; }
188
189    @Override
190    public String toString() {
191        List<Label> labelOrder = new ArrayList<>(cm.getDomain().getDomain());
192        StringBuilder sb = new StringBuilder();
193        int tp = 0;
194        int fn = 0;
195        int fp = 0;
196        int n = 0;
197        //
198        // Figure out the biggest class label and therefore the format string
199        // that we should use for them.
200        int maxLabelSize = "Balanced Error Rate".length();
201        for(Label label : labelOrder) {
202            maxLabelSize = Math.max(maxLabelSize, label.getLabel().length());
203        }
204        String labelFormatString = String.format("%%-%ds", maxLabelSize+2);
205        sb.append(String.format(labelFormatString, "Class"));
206        sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp"));
207        sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1"));
208        for (Label label : labelOrder) {
209            if (cm.support(label) == 0) {
210                continue;
211            }
212            n += cm.support(label);
213            tp += cm.tp(label);
214            fn += cm.fn(label);
215            fp += cm.fp(label);
216            sb.append(String.format(labelFormatString, label));
217            sb.append(String.format("%,12d%,12d%,12d%,12d",
218                    (int) cm.support(label),
219                    (int) cm.tp(label),
220                    (int) cm.fn(label),
221                    (int) cm.fp(label)
222            ));
223            sb.append(String.format("%12.3f%12.3f%12.3f%n", recall(label), precision(label), f1(label)));
224        }
225        sb.append(String.format(labelFormatString, "Total"));
226        sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp));
227        sb.append(String.format(labelFormatString, "Accuracy"));
228        sb.append(String.format("%60.3f%n", (double) tp / n));
229        sb.append(String.format(labelFormatString, "Micro Average"));
230        sb.append(String.format("%60.3f%12.3f%12.3f%n", microAveragedRecall(), microAveragedPrecision(), microAveragedF1()));
231        sb.append(String.format(labelFormatString, "Macro Average"));
232        sb.append(String.format("%60.3f%12.3f%12.3f%n", macroAveragedRecall(), macroAveragedPrecision(), macroAveragedF1()));
233        sb.append(String.format(labelFormatString, "Balanced Error Rate"));
234        sb.append(String.format("%60.3f", balancedErrorRate()));
235        return sb.toString();
236    }
237
238    private double get(MetricTarget<Label> tgt, LabelMetrics metric) {
239        return get(metric.forTarget(tgt).getID());
240    }
241
242    private double get(Label label, LabelMetrics metric) {
243        return get(metric
244                .forTarget(new MetricTarget<>(label))
245                .getID());
246    }
247
248    private double get(EvaluationMetric.Average avg, LabelMetrics metric) {
249        return get(metric
250                .forTarget(new MetricTarget<>(avg))
251                .getID());
252    }
253}