001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sequence;
018
019import org.tribuo.Prediction;
020import org.tribuo.classification.Label;
021import org.tribuo.classification.evaluation.LabelMetric;
022import org.tribuo.classification.evaluation.LabelMetrics;
023import org.tribuo.evaluation.metrics.MetricID;
024import org.tribuo.evaluation.metrics.MetricTarget;
025import org.tribuo.provenance.EvaluationProvenance;
026import org.tribuo.sequence.AbstractSequenceEvaluator;
027import org.tribuo.sequence.SequenceModel;
028
029import java.util.ArrayList;
030import java.util.HashSet;
031import java.util.List;
032import java.util.Map;
033import java.util.Set;
034
035/**
036 * A sequence evaluator for labels.
037 */
038public class LabelSequenceEvaluator extends AbstractSequenceEvaluator<Label, LabelMetric.Context, LabelSequenceEvaluation, LabelMetric> {
039
040    @Override
041    protected Set<LabelMetric> createMetrics(SequenceModel<Label> model) {
042        Set<LabelMetric> metrics = new HashSet<>();
043        //
044        // Populate labelwise values
045        for (Label label : model.getOutputIDInfo().getDomain()) {
046            MetricTarget<Label> tgt = new MetricTarget<>(label);
047            metrics.add(LabelMetrics.TP.forTarget(tgt));
048            metrics.add(LabelMetrics.FP.forTarget(tgt));
049            metrics.add(LabelMetrics.TN.forTarget(tgt));
050            metrics.add(LabelMetrics.FN.forTarget(tgt));
051            metrics.add(LabelMetrics.PRECISION.forTarget(tgt));
052            metrics.add(LabelMetrics.RECALL.forTarget(tgt));
053            metrics.add(LabelMetrics.F1.forTarget(tgt));
054            metrics.add(LabelMetrics.ACCURACY.forTarget(tgt));
055        }
056
057        //
058        // Populate averaged values.
059        MetricTarget<Label> micro = MetricTarget.microAverageTarget();
060        metrics.add(LabelMetrics.TP.forTarget(micro));
061        metrics.add(LabelMetrics.FP.forTarget(micro));
062        metrics.add(LabelMetrics.TN.forTarget(micro));
063        metrics.add(LabelMetrics.FN.forTarget(micro));
064        metrics.add(LabelMetrics.PRECISION.forTarget(micro));
065        metrics.add(LabelMetrics.RECALL.forTarget(micro));
066        metrics.add(LabelMetrics.F1.forTarget(micro));
067        metrics.add(LabelMetrics.ACCURACY.forTarget(micro));
068
069        MetricTarget<Label> macro = MetricTarget.macroAverageTarget();
070        metrics.add(LabelMetrics.TP.forTarget(macro));
071        metrics.add(LabelMetrics.FP.forTarget(macro));
072        metrics.add(LabelMetrics.TN.forTarget(macro));
073        metrics.add(LabelMetrics.FN.forTarget(macro));
074        metrics.add(LabelMetrics.PRECISION.forTarget(macro));
075        metrics.add(LabelMetrics.RECALL.forTarget(macro));
076        metrics.add(LabelMetrics.F1.forTarget(macro));
077        metrics.add(LabelMetrics.ACCURACY.forTarget(macro));
078
079        // Target doesn't matter for balanced error rate, so we just use
080        // average.macro as it's the macro average of recalls.
081        metrics.add(LabelMetrics.BALANCED_ERROR_RATE.forTarget(macro));
082
083        return metrics;
084    }
085
086    @Override
087    protected LabelMetric.Context createContext(SequenceModel<Label> model, List<List<Prediction<Label>>> predictions) {
088        // Warning this passes a null in as the model.
089        return new LabelMetric.Context(model, flattenList(predictions));
090    }
091
092    @Override
093    protected LabelSequenceEvaluation createEvaluation(LabelMetric.Context ctx,
094                                               Map<MetricID<Label>, Double> results,
095                                               EvaluationProvenance provenance) {
096        return new LabelSequenceEvaluation(results, ctx, provenance);
097    }
098
099    private static List<Prediction<Label>> flattenList(List<List<Prediction<Label>>> predictions) {
100        List<Prediction<Label>> flatList = new ArrayList<>();
101
102        for (List<Prediction<Label>> list : predictions) {
103            flatList.addAll(list);
104        }
105
106        return flatList;
107    }
108}