001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.evaluation;
018
019import org.tribuo.Model;
020import org.tribuo.Prediction;
021import org.tribuo.classification.Label;
022import org.tribuo.evaluation.AbstractEvaluator;
023import org.tribuo.evaluation.Evaluator;
024import org.tribuo.evaluation.metrics.MetricID;
025import org.tribuo.evaluation.metrics.MetricTarget;
026import org.tribuo.provenance.EvaluationProvenance;
027
028import java.util.HashSet;
029import java.util.List;
030import java.util.Map;
031import java.util.Set;
032
033
034/**
035 * An {@link Evaluator} for {@link Label}s.
036 * <p>
037 * The default set of metrics is taken from {@link LabelMetrics}. If the supplied
038 * model generates probabilities, then it also calculates {@link LabelMetrics#AUCROC} and
039 * {@link LabelMetrics#AVERAGED_PRECISION}.
040 * <p>
041 * If the dataset contains an unknown Label (as generated by {@link org.tribuo.classification.LabelFactory#getUnknownOutput()})
042 * or a valid Label which is outside of the domain of the {@link Model} then the evaluate methods will
043 * throw {@link IllegalArgumentException} with an appropriate message.
044 */
045public final class LabelEvaluator extends AbstractEvaluator<Label, LabelMetric.Context, LabelEvaluation, LabelMetric> {
046
047    @Override
048    protected Set<LabelMetric> createMetrics(Model<Label> model) {
049        Set<LabelMetric> metrics = new HashSet<>();
050        //
051        // Populate labelwise values
052        for (Label label : model.getOutputIDInfo().getDomain()) {
053            MetricTarget<Label> tgt = new MetricTarget<>(label);
054            metrics.add(LabelMetrics.TP.forTarget(tgt));
055            metrics.add(LabelMetrics.FP.forTarget(tgt));
056            metrics.add(LabelMetrics.TN.forTarget(tgt));
057            metrics.add(LabelMetrics.FN.forTarget(tgt));
058            metrics.add(LabelMetrics.PRECISION.forTarget(tgt));
059            metrics.add(LabelMetrics.RECALL.forTarget(tgt));
060            metrics.add(LabelMetrics.F1.forTarget(tgt));
061            metrics.add(LabelMetrics.ACCURACY.forTarget(tgt));
062            if (model.generatesProbabilities()) {
063                metrics.add(LabelMetrics.AUCROC.forTarget(tgt));
064                metrics.add(LabelMetrics.AVERAGED_PRECISION.forTarget(tgt));
065            }
066        }
067
068        //
069        // Populate averaged values.
070        MetricTarget<Label> micro = MetricTarget.microAverageTarget();
071        metrics.add(LabelMetrics.TP.forTarget(micro));
072        metrics.add(LabelMetrics.FP.forTarget(micro));
073        metrics.add(LabelMetrics.TN.forTarget(micro));
074        metrics.add(LabelMetrics.FN.forTarget(micro));
075        metrics.add(LabelMetrics.PRECISION.forTarget(micro));
076        metrics.add(LabelMetrics.RECALL.forTarget(micro));
077        metrics.add(LabelMetrics.F1.forTarget(micro));
078        metrics.add(LabelMetrics.ACCURACY.forTarget(micro));
079
080        MetricTarget<Label> macro = MetricTarget.macroAverageTarget();
081        metrics.add(LabelMetrics.TP.forTarget(macro));
082        metrics.add(LabelMetrics.FP.forTarget(macro));
083        metrics.add(LabelMetrics.TN.forTarget(macro));
084        metrics.add(LabelMetrics.FN.forTarget(macro));
085        metrics.add(LabelMetrics.PRECISION.forTarget(macro));
086        metrics.add(LabelMetrics.RECALL.forTarget(macro));
087        metrics.add(LabelMetrics.F1.forTarget(macro));
088        metrics.add(LabelMetrics.ACCURACY.forTarget(macro));
089
090        // Target doesn't matter for balanced error rate, so we just use
091        // average.macro as it's the macro average of recalls.
092        metrics.add(LabelMetrics.BALANCED_ERROR_RATE.forTarget(macro));
093
094        return metrics;
095    }
096
097    @Override
098    protected LabelMetric.Context createContext(Model<Label> model, List<Prediction<Label>> predictions) {
099        return new LabelMetric.Context(model, predictions);
100    }
101
102    @Override
103    protected LabelEvaluation createEvaluation(LabelMetric.Context ctx,
104                                               Map<MetricID<Label>, Double> results,
105                                               EvaluationProvenance provenance) {
106        return new LabelEvaluationImpl(results, ctx, provenance);
107    }
108}