/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.evaluation;

import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.evaluation.AbstractEvaluator;
import org.tribuo.evaluation.Evaluator;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.provenance.EvaluationProvenance;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * An {@link Evaluator} for {@link Label}s.
 * <p>
 * The default set of metrics is taken from {@link LabelMetrics}. If the supplied
 * model generates probabilities, then it also calculates {@link LabelMetrics#AUCROC} and
 * {@link LabelMetrics#AVERAGED_PRECISION}.
 * <p>
 * If the dataset contains an unknown Label (as generated by {@link org.tribuo.classification.LabelFactory#getUnknownOutput()}),
 * or a valid Label which is outside the domain of the {@link Model}, then the evaluate methods
 * throw {@link IllegalArgumentException} with an appropriate message.
 */
public final class LabelEvaluator extends AbstractEvaluator<Label, LabelMetric.Context, LabelEvaluation, LabelMetric> {

    @Override
    protected Set<LabelMetric> createMetrics(Model<Label> model) {
        Set<LabelMetric> metrics = new HashSet<>();
        //
        // Populate labelwise values.
        for (Label label : model.getOutputIDInfo().getDomain()) {
            MetricTarget<Label> tgt = new MetricTarget<>(label);
            metrics.add(LabelMetrics.TP.forTarget(tgt));
            metrics.add(LabelMetrics.FP.forTarget(tgt));
            metrics.add(LabelMetrics.TN.forTarget(tgt));
            metrics.add(LabelMetrics.FN.forTarget(tgt));
            metrics.add(LabelMetrics.PRECISION.forTarget(tgt));
            metrics.add(LabelMetrics.RECALL.forTarget(tgt));
            metrics.add(LabelMetrics.F1.forTarget(tgt));
            metrics.add(LabelMetrics.ACCURACY.forTarget(tgt));
            if (model.generatesProbabilities()) {
                metrics.add(LabelMetrics.AUCROC.forTarget(tgt));
                metrics.add(LabelMetrics.AVERAGED_PRECISION.forTarget(tgt));
            }
        }

        //
        // Populate averaged values.
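        // Micro averaging pools the per-prediction counts across all labels before
        // computing each metric, so frequent labels dominate the result; macro
        // averaging computes each metric per label and takes the unweighted mean,
        // so every label counts equally regardless of its frequency.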
        MetricTarget<Label> micro = MetricTarget.microAverageTarget();
        metrics.add(LabelMetrics.TP.forTarget(micro));
        metrics.add(LabelMetrics.FP.forTarget(micro));
        metrics.add(LabelMetrics.TN.forTarget(micro));
        metrics.add(LabelMetrics.FN.forTarget(micro));
        metrics.add(LabelMetrics.PRECISION.forTarget(micro));
        metrics.add(LabelMetrics.RECALL.forTarget(micro));
        metrics.add(LabelMetrics.F1.forTarget(micro));
        metrics.add(LabelMetrics.ACCURACY.forTarget(micro));

        MetricTarget<Label> macro = MetricTarget.macroAverageTarget();
        metrics.add(LabelMetrics.TP.forTarget(macro));
        metrics.add(LabelMetrics.FP.forTarget(macro));
        metrics.add(LabelMetrics.TN.forTarget(macro));
        metrics.add(LabelMetrics.FN.forTarget(macro));
        metrics.add(LabelMetrics.PRECISION.forTarget(macro));
        metrics.add(LabelMetrics.RECALL.forTarget(macro));
        metrics.add(LabelMetrics.F1.forTarget(macro));
        metrics.add(LabelMetrics.ACCURACY.forTarget(macro));

        // The target doesn't matter for the balanced error rate, so we just use
        // the macro-average target, as BER is derived from the macro average of
        // the per-label recalls.
        metrics.add(LabelMetrics.BALANCED_ERROR_RATE.forTarget(macro));

        return metrics;
    }

    @Override
    protected LabelMetric.Context createContext(Model<Label> model, List<Prediction<Label>> predictions) {
        return new LabelMetric.Context(model, predictions);
    }

    @Override
    protected LabelEvaluation createEvaluation(LabelMetric.Context ctx,
                                               Map<MetricID<Label>, Double> results,
                                               EvaluationProvenance provenance) {
        return new LabelEvaluationImpl(results, ctx, provenance);
    }
}
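// A minimal usage sketch, kept as a comment since it is not part of this class.
// The names `model` and `testSet` are hypothetical placeholders for a trained
// Model<Label> and a held-out Dataset<Label>; LabelEvaluator and LabelEvaluation
// are the types defined in this package.
//
//     LabelEvaluator evaluator = new LabelEvaluator();
//     LabelEvaluation evaluation = evaluator.evaluate(model, testSet);
//     double accuracy = evaluation.accuracy();        // micro-averaged accuracy
//     double macroF1 = evaluation.macroAveragedF1();  // unweighted mean of per-label F1
//     System.out.println(evaluation.toString());      // formatted per-label summary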