001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.sequence; 018 019import org.tribuo.Output; 020import org.tribuo.Prediction; 021import org.tribuo.evaluation.metrics.EvaluationMetric; 022import org.tribuo.evaluation.metrics.MetricContext; 023import org.tribuo.evaluation.metrics.MetricID; 024import org.tribuo.provenance.DataProvenance; 025import org.tribuo.provenance.EvaluationProvenance; 026 027import java.util.HashMap; 028import java.util.List; 029import java.util.Map; 030import java.util.Set; 031 032/** 033 * Base class for sequence evaluators. 034 */ 035public abstract class AbstractSequenceEvaluator< 036 T extends Output<T>, 037 C extends MetricContext<T>, 038 E extends SequenceEvaluation<T>, 039 M extends EvaluationMetric<T, C>> implements SequenceEvaluator<T, E> { 040 041 /** 042 * Produces an evaluation for the supplied model and dataset, by calling {@link SequenceModel#predict} 043 * to create the predictions, then aggregating the appropriate statistics. 044 * @param model The model to use. 045 * @param dataset The dataset to make predictions for. 046 * @return An evaluation of the dataset on the model. 047 */ 048 @Override 049 public final E evaluate(SequenceModel<T> model, SequenceDataset<T> dataset) { 050 // 051 // Run the model against the dataset to get predictions 052 List<List<Prediction<T>>> predictions = model.predict(dataset); 053 return evaluate(model, predictions, dataset.getProvenance()); 054 } 055 056 /** 057 * Produces an evaluation for the supplied model and datasource, by calling {@link SequenceModel#predict} 058 * to create the predictions, then aggregating the appropriate statistics. 059 * @param model The model to use. 060 * @param datasource The datasource to make predictions for. 061 * @return An evaluation of the datasource on the model. 062 */ 063 @Override 064 public final E evaluate(SequenceModel<T> model, SequenceDataSource<T> datasource) { 065 // 066 // Run the model against the datasource to get predictions 067 List<List<Prediction<T>>> predictions = model.predict(datasource); 068 return evaluate(model, predictions, datasource.getProvenance()); 069 } 070 071 // "template method" 072 073 /** 074 * Produces an evaluation for the supplied model and predictions by aggregating the appropriate statistics. 075 * <p> 076 * Warning, this method cannot validate that the predictions were returned by the model in question. 077 * @param model The model to use. 078 * @param predictions The predictions to use. 079 * @param dataProvenance The provenance of the test data. 080 * @return An evaluation of the predictions. 081 */ 082 @Override 083 public final E evaluate(SequenceModel<T> model, List<List<Prediction<T>>> predictions, DataProvenance dataProvenance) { 084 // 085 // Create the provenance for the model and dataset 086 EvaluationProvenance provenance = new EvaluationProvenance(model.getProvenance(), dataProvenance); 087 // 088 // Create an evaluation context. The context stores all the information needed by the list of metrics plus might 089 // cache intermediate computation relevant to multiple metrics (e.g., a pre-computed confusion matrix might be stored in 'context') 090 C context = createContext(model, predictions); 091 // 092 // "MODEL": Build the list of metrics to compute. 093 Set<? extends EvaluationMetric<T, C>> metrics = createMetrics(model); 094 // 095 // "CONTROLLER": For each metric in the list, compute the result. 096 Map<MetricID<T>, Double> results = computeResults(context, metrics); 097 // 098 // "VIEW": Create an evaluation to store the results and provide a "view" of the results to users 099 return createEvaluation(context, results, provenance); 100 } 101 102 /** 103 * Computes each metric given the context. 104 * @param ctx The metric context (i.e., the sufficient statistics). 105 * @param metrics The metrics to compute. 106 * @return The value of each requested metric. 107 */ 108 protected Map<MetricID<T>, Double> computeResults(C ctx, Set<? extends EvaluationMetric<T, C>> metrics) { 109 Map<MetricID<T>, Double> results = new HashMap<>(); 110 for (EvaluationMetric<T, C> metric : metrics) { 111 MetricID<T> id = metric.getID(); 112 double value = metric.compute(ctx); 113 results.put(id, value); 114 } 115 return results; 116 } 117 118 /** 119 * Creates the appropriate set of metrics for this model, by querying for it's {@link org.tribuo.OutputInfo}. 120 * @param model The model to inspect. 121 * @return The set of metrics. 122 */ 123 protected abstract Set<M> createMetrics(SequenceModel<T> model); 124 125 // 126 // Note: the following two methods are abstract (plus the 'C' type parameter) to make memoization work smoothly, basically. 127 128 /** 129 * Create the context needed for evaluation. The context might store global properties or cache computation. 130 * @param model the model that will be evaluated 131 * @param predictions the predictions that will be evaluated 132 * @return the context for this model and its predictions 133 */ 134 protected abstract C createContext(SequenceModel<T> model, List<List<Prediction<T>>> predictions); 135 136 /** 137 * Create an evaluation for the given results 138 * @param context the context that was used to compute these results 139 * @param results the results 140 * @param provenance the provenance of the results (including information about the model and dataset) 141 * @return the evaluation 142 */ 143 protected abstract E createEvaluation(C context, 144 Map<MetricID<T>, Double> results, 145 EvaluationProvenance provenance); 146}