/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.sequence;

import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.EvaluationProvenance;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Base class for sequence evaluators.
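 * <p>
 * Concrete subclasses supply the metrics, the metric context, and the evaluation type. A usage
 * sketch, assuming a hypothetical concrete subclass {@code MySequenceEvaluator} for {@code Label}
 * outputs (the subclass and variable names are illustrative, not part of this API):
 * <pre>{@code
 * SequenceModel<Label> model = trainer.train(trainData);
 * MySequenceEvaluator evaluator = new MySequenceEvaluator();
 * SequenceEvaluation<Label> evaluation = evaluator.evaluate(model, testData);
 * }</pre>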
 */
public abstract class AbstractSequenceEvaluator<
        T extends Output<T>,
        C extends MetricContext<T>,
        E extends SequenceEvaluation<T>,
        M extends EvaluationMetric<T, C>> implements SequenceEvaluator<T, E> {

    /**
     * Produces an evaluation for the supplied model and dataset, by calling {@link SequenceModel#predict}
     * to create the predictions, then aggregating the appropriate statistics.
     * @param model The model to use.
     * @param dataset The dataset to make predictions for.
     * @return An evaluation of the dataset on the model.
     */
    @Override
    public final E evaluate(SequenceModel<T> model, SequenceDataset<T> dataset) {
        //
        // Run the model against the dataset to get predictions
        List<List<Prediction<T>>> predictions = model.predict(dataset);
        return evaluate(model, predictions, dataset.getProvenance());
    }

    /**
     * Produces an evaluation for the supplied model and datasource, by calling {@link SequenceModel#predict}
     * to create the predictions, then aggregating the appropriate statistics.
     * @param model The model to use.
     * @param datasource The datasource to make predictions for.
     * @return An evaluation of the datasource on the model.
     */
    @Override
    public final E evaluate(SequenceModel<T> model, SequenceDataSource<T> datasource) {
        //
        // Run the model against the datasource to get predictions
        List<List<Prediction<T>>> predictions = model.predict(datasource);
        return evaluate(model, predictions, datasource.getProvenance());
    }

    // "template method"

    /**
     * Produces an evaluation for the supplied model and predictions by aggregating the appropriate statistics.
     * <p>
     * Warning: this method cannot validate that the predictions were returned by the model in question.
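     * <p>
     * This entry point is useful when the predictions were computed elsewhere (e.g., loaded from disk
     * or produced by an external job). A sketch, assuming a hypothetical concrete subclass
     * {@code MySequenceEvaluator}:
     * <pre>{@code
     * List<List<Prediction<Label>>> predictions = model.predict(testData);
     * SequenceEvaluation<Label> evaluation =
     *     new MySequenceEvaluator().evaluate(model, predictions, testData.getProvenance());
     * }</pre>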
     * @param model The model to use.
     * @param predictions The predictions to use.
     * @param dataProvenance The provenance of the test data.
     * @return An evaluation of the predictions.
     */
    @Override
    public final E evaluate(SequenceModel<T> model, List<List<Prediction<T>>> predictions, DataProvenance dataProvenance) {
        //
        // Create the provenance for the model and dataset
        EvaluationProvenance provenance = new EvaluationProvenance(model.getProvenance(), dataProvenance);
        //
        // Create an evaluation context. The context stores all the information needed by the list of metrics, and it may
        // cache intermediate computation relevant to multiple metrics (e.g., a pre-computed confusion matrix might be stored in 'context').
        C context = createContext(model, predictions);
        //
        // "MODEL": Build the list of metrics to compute.
        Set<? extends EvaluationMetric<T, C>> metrics = createMetrics(model);
        //
        // "CONTROLLER": For each metric in the list, compute the result.
        Map<MetricID<T>, Double> results = computeResults(context, metrics);
        //
        // "VIEW": Create an evaluation to store the results and provide a "view" of the results to users
        return createEvaluation(context, results, provenance);
    }

    /**
     * Computes each metric given the context.
     * @param ctx The metric context (i.e., the sufficient statistics).
     * @param metrics The metrics to compute.
     * @return The value of each requested metric.
     */
    protected Map<MetricID<T>, Double> computeResults(C ctx, Set<? extends EvaluationMetric<T, C>> metrics) {
        Map<MetricID<T>, Double> results = new HashMap<>();
        for (EvaluationMetric<T, C> metric : metrics) {
            MetricID<T> id = metric.getID();
            double value = metric.compute(ctx);
            results.put(id, value);
        }
        return results;
    }

    /**
     * Creates the appropriate set of metrics for this model, by querying its {@link org.tribuo.OutputInfo}.
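     * <p>
     * A minimal sketch of an override, building one metric per element of the output domain (the
     * {@code MyMetric} type and its factory method are illustrative, not part of this API):
     * <pre>{@code
     * protected Set<MyMetric> createMetrics(SequenceModel<Label> model) {
     *     Set<MyMetric> metrics = new HashSet<>();
     *     for (Label label : model.getOutputIDInfo().getDomain()) {
     *         metrics.add(MyMetric.forLabel(label)); // one per-output metric
     *     }
     *     return metrics;
     * }
     * }</pre>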
     * @param model The model to inspect.
     * @return The set of metrics.
     */
    protected abstract Set<M> createMetrics(SequenceModel<T> model);

    //
    // Note: the following two methods (along with the 'C' type parameter) are abstract to make memoization work smoothly.

    /**
     * Create the context needed for evaluation. The context might store global properties or cache computation.
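     * <p>
     * Implementations typically compute shared sufficient statistics once here so every metric reads
     * from the same cache. A sketch, assuming a hypothetical {@code MyContext} whose constructor
     * tabulates a confusion matrix from the predictions:
     * <pre>{@code
     * protected MyContext createContext(SequenceModel<Label> model, List<List<Prediction<Label>>> predictions) {
     *     return new MyContext(model, predictions); // statistics computed once, reused by all metrics
     * }
     * }</pre>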
     * @param model the model that will be evaluated
     * @param predictions the predictions that will be evaluated
     * @return the context for this model and its predictions
     */
    protected abstract C createContext(SequenceModel<T> model, List<List<Prediction<T>>> predictions);

    /**
     * Create an evaluation for the given results.
     * @param context the context that was used to compute these results
     * @param results the results
     * @param provenance the provenance of the results (including information about the model and dataset)
     * @return the evaluation
     */
    protected abstract E createEvaluation(C context,
                                          Map<MetricID<T>, Double> results,
                                          EvaluationProvenance provenance);
}