/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.evaluation;

import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.EvaluationProvenance;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Base class for evaluators.
 */
public abstract class AbstractEvaluator<
        T extends Output<T>,
        C extends MetricContext<T>,
        E extends Evaluation<T>,
        M extends EvaluationMetric<T, C>> implements Evaluator<T, E> {

    /**
     * Produces an evaluation for the supplied model and dataset, by calling {@link Model#predict}
     * to create the predictions, then aggregating the appropriate statistics.
     * @param model The model to use.
     * @param dataset The dataset to make predictions for.
     * @return An evaluation of the dataset on the model.
     */
    @Override
    public final E evaluate(Model<T> model, Dataset<T> dataset) {
        OutputFactory<T> factory = dataset.getOutputFactory();
        int i = 0;
        for (Example<T> example : dataset) {
            if (factory.getUnknownOutput().equals(example.getOutput())) {
                throw new IllegalArgumentException("The sentinel Unknown Output was used as a ground truth label in example number " + i);
            }
            i++;
        }
        //
        // Run the model against the dataset to get predictions
        List<Prediction<T>> predictions = model.predict(dataset);
        return evaluate(model, predictions, dataset.getProvenance());
    }
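    // Illustrative call-site sketch (not part of this class): how the dataset overload above
    // is typically driven from a concrete evaluator. LabelEvaluator and LabelEvaluation come
    // from Tribuo's classification module; 'model' and 'testData' are assumed to exist.
    //
    //   LabelEvaluator evaluator = new LabelEvaluator();
    //   LabelEvaluation evaluation = evaluator.evaluate(model, testData);
    //   double accuracy = evaluation.accuracy();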
    /**
     * Produces an evaluation for the supplied model and datasource, by calling {@link Model#predict}
     * to create the predictions, then aggregating the appropriate statistics.
     * @param model The model to use.
     * @param datasource The datasource to make predictions for.
     * @return An evaluation of the datasource on the model.
     */
    @Override
    public final E evaluate(Model<T> model, DataSource<T> datasource) {
        OutputFactory<T> factory = datasource.getOutputFactory();
        List<Example<T>> examples = new ArrayList<>();
        for (Example<T> example : datasource) {
            if (factory.getUnknownOutput().equals(example.getOutput())) {
                throw new IllegalArgumentException("The sentinel Unknown Output was used as a ground truth label in example number " + examples.size());
            }
            examples.add(example);
        }
        //
        // Run the model against the examples to get predictions
        List<Prediction<T>> predictions = model.predict(examples);
        return evaluate(model, predictions, datasource.getProvenance());
    }

    // "template method"

    /**
     * Produces an evaluation for the supplied model and predictions by aggregating the appropriate statistics.
     * <p>
     * Warning: this method cannot validate that the predictions were returned by the model in question.
     * @param model The model to use.
     * @param predictions The predictions to use.
     * @param dataProvenance The provenance of the test data.
     * @return An evaluation of the predictions.
     */
    @Override
    public final E evaluate(Model<T> model, List<Prediction<T>> predictions, DataProvenance dataProvenance) {
        //
        // Create the provenance for the model and dataset
        EvaluationProvenance provenance = new EvaluationProvenance(model.getProvenance(), dataProvenance);
        //
        // Create an evaluation context. The context stores all the information needed by the list of
        // metrics, and may also cache intermediate computations shared by several metrics (e.g., a
        // pre-computed confusion matrix could be stored in 'context').
        C context = createContext(model, predictions);
        //
        // "MODEL": Build the list of metrics to compute.
        Set<? extends EvaluationMetric<T, C>> metrics = createMetrics(model);
        //
        // "CONTROLLER": For each metric in the list, compute the result.
        Map<MetricID<T>, Double> results = computeResults(context, metrics);
        //
        // "VIEW": Create an evaluation to store the results and provide a "view" of the results to users.
        return createEvaluation(context, results, provenance);
    }

    /**
     * Computes each metric given the context.
     * @param ctx The metric context (i.e., the sufficient statistics).
     * @param metrics The metrics to compute.
     * @return The value of each requested metric.
     */
    protected Map<MetricID<T>, Double> computeResults(C ctx, Set<? extends EvaluationMetric<T, C>> metrics) {
        Map<MetricID<T>, Double> results = new HashMap<>();
        for (EvaluationMetric<T, C> metric : metrics) {
            MetricID<T> id = metric.getID();
            double value = metric.compute(ctx);
            results.put(id, value);
        }
        return results;
    }

    /**
     * Creates the appropriate set of metrics for this model, by querying its {@link org.tribuo.OutputInfo}.
     * @param model The model to inspect.
     * @return The set of metrics.
     */
    protected abstract Set<M> createMetrics(Model<T> model);

    //
    // Note: the following two methods (along with the 'C' type parameter) are abstract so that
    // subclasses can memoize computation shared between metrics inside the context.
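    // Illustrative sketch (hypothetical names, not Tribuo API) of the memoization the note above
    // refers to: a context subclass tabulates the shared sufficient statistics once in its
    // constructor, and every metric then reads them in compute(ctx) instead of re-deriving them.
    //
    //   public final class MyContext extends MetricContext<MyOutput> {
    //       private final MyConfusionMatrix cm;
    //       MyContext(Model<MyOutput> model, List<Prediction<MyOutput>> predictions) {
    //           super(model, predictions);
    //           this.cm = new MyConfusionMatrix(predictions); // tabulated once, shared by all metrics
    //       }
    //       MyConfusionMatrix getCM() { return cm; }
    //   }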
    /**
     * Creates the context needed for evaluation. The context may store global properties
     * or cache intermediate computations.
     * @param model The model that will be evaluated.
     * @param predictions The predictions that will be evaluated.
     * @return The context for this model and its predictions.
     */
    protected abstract C createContext(Model<T> model, List<Prediction<T>> predictions);

    /**
     * Creates an evaluation for the given results.
     * @param context The context that was used to compute these results.
     * @param results The results.
     * @param provenance The provenance of the results (including information about the model and dataset).
     * @return The evaluation.
     */
    protected abstract E createEvaluation(C context,
                                          Map<MetricID<T>, Double> results,
                                          EvaluationProvenance provenance);
}
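// Illustrative skeleton (hypothetical names, not Tribuo API): the minimum a concrete evaluator
// needs to supply. Only the three abstract members below must be implemented; the final
// evaluate overloads in AbstractEvaluator drive them via the template method above.
//
//   public final class MyEvaluator
//           extends AbstractEvaluator<MyOutput, MyContext, MyEvaluation, MyMetric> {
//       @Override
//       protected Set<MyMetric> createMetrics(Model<MyOutput> model) { /* ... */ }
//       @Override
//       protected MyContext createContext(Model<MyOutput> model,
//                                         List<Prediction<MyOutput>> predictions) { /* ... */ }
//       @Override
//       protected MyEvaluation createEvaluation(MyContext context,
//                                               Map<MetricID<MyOutput>, Double> results,
//                                               EvaluationProvenance provenance) { /* ... */ }
//   }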