001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel.evaluation;
018
019import org.tribuo.Model;
020import org.tribuo.Prediction;
021import org.tribuo.evaluation.AbstractEvaluator;
022import org.tribuo.evaluation.Evaluator;
023import org.tribuo.evaluation.metrics.EvaluationMetric.Average;
024import org.tribuo.evaluation.metrics.MetricID;
025import org.tribuo.evaluation.metrics.MetricTarget;
026import org.tribuo.multilabel.MultiLabel;
027import org.tribuo.provenance.EvaluationProvenance;
028
029import java.util.HashSet;
030import java.util.List;
031import java.util.Map;
032import java.util.Set;
033
034
035/**
036 * An {@link Evaluator} for {@link MultiLabel} problems.
037 * <p>
038 * If the dataset contains an unknown MultiLabel (as generated by {@link org.tribuo.multilabel.MultiLabelFactory#getUnknownOutput()})
039 * or a valid MultiLabel which is outside of the domain of the {@link Model} then the evaluate methods will
040 * throw {@link IllegalArgumentException} with an appropriate message.
041 */
042public class MultiLabelEvaluator extends AbstractEvaluator<MultiLabel, MultiLabelMetric.Context, MultiLabelEvaluation, MultiLabelMetric> {
043
044    @Override
045    protected Set<MultiLabelMetric> createMetrics(Model<MultiLabel> model) {
046        Set<MultiLabelMetric> metrics = new HashSet<>();
047        //
048        // Populate labelwise values
049        for (MultiLabel label : model.getOutputIDInfo().getDomain()) {
050            MetricTarget<MultiLabel> tgt = new MetricTarget<>(label);
051            metrics.add(MultiLabelMetrics.TP.forTarget(tgt));
052            metrics.add(MultiLabelMetrics.FP.forTarget(tgt));
053            metrics.add(MultiLabelMetrics.TN.forTarget(tgt));
054            metrics.add(MultiLabelMetrics.FN.forTarget(tgt));
055            metrics.add(MultiLabelMetrics.PRECISION.forTarget(tgt));
056            metrics.add(MultiLabelMetrics.RECALL.forTarget(tgt));
057            metrics.add(MultiLabelMetrics.F1.forTarget(tgt));
058        }
059
060        //
061        // Populate averaged values.
062        for (Average avg : Average.values()) {
063            MetricTarget<MultiLabel> tgt = new MetricTarget<>(avg);
064            metrics.add(MultiLabelMetrics.TP.forTarget(tgt));
065            metrics.add(MultiLabelMetrics.FP.forTarget(tgt));
066            metrics.add(MultiLabelMetrics.TN.forTarget(tgt));
067            metrics.add(MultiLabelMetrics.FN.forTarget(tgt));
068            metrics.add(MultiLabelMetrics.PRECISION.forTarget(tgt));
069            metrics.add(MultiLabelMetrics.RECALL.forTarget(tgt));
070            metrics.add(MultiLabelMetrics.F1.forTarget(tgt));
071        }
072
073        // Target doesn't matter for balanced error rate, so we just use Average.macro
074        // as it's the macro average of the recalls.
075        MetricTarget<MultiLabel> dummy = new MetricTarget<>(Average.MACRO);
076        metrics.add(MultiLabelMetrics.BALANCED_ERROR_RATE.forTarget(dummy));
077
078        return metrics;
079    }
080
081    @Override
082    protected MultiLabelMetric.Context createContext(Model<MultiLabel> model, List<Prediction<MultiLabel>> predictions) {
083        return MultiLabelMetric.buildContext(model, predictions);
084    }
085
086    @Override
087    protected MultiLabelEvaluation createEvaluation(MultiLabelMetric.Context context,
088                                                        Map<MetricID<MultiLabel>, Double> results,
089                                                        EvaluationProvenance provenance) {
090        return new MultiLabelEvaluationImpl(results, context, provenance);
091    }
092}