/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.multilabel.evaluation;

import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.evaluation.AbstractEvaluator;
import org.tribuo.evaluation.Evaluator;
import org.tribuo.evaluation.metrics.EvaluationMetric.Average;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.provenance.EvaluationProvenance;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * An {@link Evaluator} for {@link MultiLabel} problems.
 * <p>
 * If the dataset contains an unknown MultiLabel (as generated by {@link org.tribuo.multilabel.MultiLabelFactory#getUnknownOutput()})
 * or a valid MultiLabel which is outside of the domain of the {@link Model} then the evaluate methods will
 * throw {@link IllegalArgumentException} with an appropriate message.
 */
public class MultiLabelEvaluator extends AbstractEvaluator<MultiLabel, MultiLabelMetric.Context, MultiLabelEvaluation, MultiLabelMetric> {

    @Override
    protected Set<MultiLabelMetric> createMetrics(Model<MultiLabel> model) {
        Set<MultiLabelMetric> metrics = new HashSet<>();
        //
        // Populate labelwise values.
        for (MultiLabel label : model.getOutputIDInfo().getDomain()) {
            MetricTarget<MultiLabel> tgt = new MetricTarget<>(label);
            metrics.add(MultiLabelMetrics.TP.forTarget(tgt));
            metrics.add(MultiLabelMetrics.FP.forTarget(tgt));
            metrics.add(MultiLabelMetrics.TN.forTarget(tgt));
            metrics.add(MultiLabelMetrics.FN.forTarget(tgt));
            metrics.add(MultiLabelMetrics.PRECISION.forTarget(tgt));
            metrics.add(MultiLabelMetrics.RECALL.forTarget(tgt));
            metrics.add(MultiLabelMetrics.F1.forTarget(tgt));
        }

        //
        // Populate averaged values.
        for (Average avg : Average.values()) {
            MetricTarget<MultiLabel> tgt = new MetricTarget<>(avg);
            metrics.add(MultiLabelMetrics.TP.forTarget(tgt));
            metrics.add(MultiLabelMetrics.FP.forTarget(tgt));
            metrics.add(MultiLabelMetrics.TN.forTarget(tgt));
            metrics.add(MultiLabelMetrics.FN.forTarget(tgt));
            metrics.add(MultiLabelMetrics.PRECISION.forTarget(tgt));
            metrics.add(MultiLabelMetrics.RECALL.forTarget(tgt));
            metrics.add(MultiLabelMetrics.F1.forTarget(tgt));
        }

        // The target doesn't matter for the balanced error rate, so we just use Average.MACRO
        // as it's the macro average of the recalls.
        MetricTarget<MultiLabel> dummy = new MetricTarget<>(Average.MACRO);
        metrics.add(MultiLabelMetrics.BALANCED_ERROR_RATE.forTarget(dummy));

        return metrics;
    }

    @Override
    protected MultiLabelMetric.Context createContext(Model<MultiLabel> model, List<Prediction<MultiLabel>> predictions) {
        return MultiLabelMetric.buildContext(model, predictions);
    }

    @Override
    protected MultiLabelEvaluation createEvaluation(MultiLabelMetric.Context context,
                                                    Map<MetricID<MultiLabel>, Double> results,
                                                    EvaluationProvenance provenance) {
        return new MultiLabelEvaluationImpl(results, context, provenance);
    }
}
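
// A minimal usage sketch (not part of the original file): it shows how this evaluator is
// typically driven through the inherited Evaluator#evaluate(Model, Dataset) entry point,
// which builds the metric set above, computes each metric over the predictions, and packs
// the results into a MultiLabelEvaluation. The "model" and "testData" names are
// placeholder variables for a trained Model<MultiLabel> and a held-out Dataset<MultiLabel>.
//
//   MultiLabelEvaluator evaluator = new MultiLabelEvaluator();
//   MultiLabelEvaluation evaluation = evaluator.evaluate(model, testData);
//   // Per-label and averaged metrics registered in createMetrics are available by MetricID.
//   Map<MetricID<MultiLabel>, Double> results = evaluation.asMap();
//   System.out.println(evaluation.toString());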