/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.evaluation;

import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.EvaluationProvenance;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Base class for evaluators.
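 * <p>
 * A minimal usage sketch, assuming a concrete task-specific subclass (for example the classification
 * {@code LabelEvaluator}), an already trained model, and a held-out test dataset:
 * <pre>{@code
 * LabelEvaluator evaluator = new LabelEvaluator();
 * LabelEvaluation evaluation = evaluator.evaluate(model, testDataset);
 * double accuracy = evaluation.accuracy();
 * }</pre>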
 */
public abstract class AbstractEvaluator<
        T extends Output<T>,
        C extends MetricContext<T>,
        E extends Evaluation<T>,
        M extends EvaluationMetric<T, C>> implements Evaluator<T, E> {

    /**
     * Produces an evaluation for the supplied model and dataset, by calling {@link Model#predict}
     * to create the predictions, then aggregating the appropriate statistics.
     * @param model The model to use.
     * @param dataset The dataset to make predictions for.
     * @return An evaluation of the dataset on the model.
     */
    @Override
    public final E evaluate(Model<T> model, Dataset<T> dataset) {
        OutputFactory<T> factory = dataset.getOutputFactory();
        int i = 0;
        for (Example<T> example : dataset) {
            if (factory.getUnknownOutput().equals(example.getOutput())) {
                throw new IllegalArgumentException("The sentinel Unknown Output was used as a ground truth label in example number " + i);
            }
            i++;
        }
        //
        // Run the model against the dataset to get predictions
        List<Prediction<T>> predictions = model.predict(dataset);
        return evaluate(model, predictions, dataset.getProvenance());
    }

    /**
     * Produces an evaluation for the supplied model and datasource, by calling {@link Model#predict}
     * to create the predictions, then aggregating the appropriate statistics.
     * @param model The model to use.
     * @param datasource The datasource to make predictions for.
     * @return An evaluation of the datasource on the model.
     */
    @Override
    public final E evaluate(Model<T> model, DataSource<T> datasource) {
        OutputFactory<T> factory = datasource.getOutputFactory();
        List<Example<T>> examples = new ArrayList<>();
        for (Example<T> example : datasource) {
            if (factory.getUnknownOutput().equals(example.getOutput())) {
                throw new IllegalArgumentException("The sentinel Unknown Output was used as a ground truth label in example number " + examples.size());
            }
            examples.add(example);
        }
        //
        // Run the model against the examples drawn from the datasource to get predictions
        List<Prediction<T>> predictions = model.predict(examples);
        return evaluate(model, predictions, datasource.getProvenance());
    }

    // "template method"

    /**
     * Produces an evaluation for the supplied model and predictions by aggregating the appropriate statistics.
     * <p>
     * Warning, this method cannot validate that the predictions were returned by the model in question.
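     * <p>
     * A sketch of typical use, assuming a classification {@code LabelEvaluator}; the predictions here come from
     * {@link Model#predict}, but this overload accepts any predictions paired with the matching data provenance:
     * <pre>{@code
     * List<Prediction<Label>> predictions = model.predict(testDataset);
     * LabelEvaluation evaluation = evaluator.evaluate(model, predictions, testDataset.getProvenance());
     * }</pre>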
     * @param model The model to use.
     * @param predictions The predictions to use.
     * @param dataProvenance The provenance of the test data.
     * @return An evaluation of the predictions.
     */
    @Override
    public final E evaluate(Model<T> model, List<Prediction<T>> predictions, DataProvenance dataProvenance) {
        //
        // Create the provenance for the model and dataset
        EvaluationProvenance provenance = new EvaluationProvenance(model.getProvenance(), dataProvenance);
        //
        // Create an evaluation context. The context stores all the information needed by the metrics, and it may
        // also cache intermediate computations that several metrics share (e.g., a pre-computed confusion matrix).
        C context = createContext(model, predictions);
        //
        // "MODEL": Build the list of metrics to compute.
        Set<? extends EvaluationMetric<T, C>> metrics = createMetrics(model);
        //
        // "CONTROLLER": For each metric in the list, compute the result.
        Map<MetricID<T>, Double> results = computeResults(context, metrics);
        //
        // "VIEW": Create an evaluation to store the results and provide a "view" of the results to users
        return createEvaluation(context, results, provenance);
    }

    /**
     * Computes each metric given the context.
     * @param ctx The metric context (i.e., the sufficient statistics).
     * @param metrics The metrics to compute.
     * @return The value of each requested metric.
     */
    protected Map<MetricID<T>, Double> computeResults(C ctx, Set<? extends EvaluationMetric<T, C>> metrics) {
        Map<MetricID<T>, Double> results = new HashMap<>();
        for (EvaluationMetric<T, C> metric : metrics) {
            MetricID<T> id = metric.getID();
            double value = metric.compute(ctx);
            results.put(id, value);
        }
        return results;
    }

    /**
     * Creates the appropriate set of metrics for this model by querying its {@link org.tribuo.OutputInfo}.
     * @param model The model to inspect.
     * @return The set of metrics.
     */
    protected abstract Set<M> createMetrics(Model<T> model);

    //
    // Note: the following two methods are abstract (and the 'C' type parameter exists) primarily so that
    // memoization of shared intermediate computations works smoothly.
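    //
    // A sketch of how a hypothetical subclass wires these pieces together (all 'My*' names below are
    // illustrative placeholders, not actual Tribuo classes):
    //
    //   public final class MyEvaluator extends AbstractEvaluator<MyOutput, MyContext, MyEvaluation, MyMetric> {
    //       @Override
    //       protected Set<MyMetric> createMetrics(Model<MyOutput> model) {
    //           // choose metrics appropriate to the outputs the model knows about
    //           return Set.of(new MyMetric("accuracy"), new MyMetric("f1"));
    //       }
    //       @Override
    //       protected MyContext createContext(Model<MyOutput> model, List<Prediction<MyOutput>> predictions) {
    //           // compute shared statistics once (e.g., a confusion matrix) so each metric can reuse them
    //           return new MyContext(model, predictions);
    //       }
    //       @Override
    //       protected MyEvaluation createEvaluation(MyContext context, Map<MetricID<MyOutput>, Double> results, EvaluationProvenance provenance) {
    //           return new MyEvaluation(results, context, provenance);
    //       }
    //   }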

    /**
     * Create the context needed for evaluation. The context might store global properties or cache computation.
     * @param model the model that will be evaluated
     * @param predictions the predictions that will be evaluated
     * @return the context for this model and its predictions
     */
    protected abstract C createContext(Model<T> model, List<Prediction<T>> predictions);

    /**
     * Create an evaluation for the given results.
     * @param context the context that was used to compute these results
     * @param results the results
     * @param provenance the provenance of the results (including information about the model and dataset)
     * @return the evaluation
     */
    protected abstract E createEvaluation(C context,
                                          Map<MetricID<T>, Double> results,
                                          EvaluationProvenance provenance);
}