/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.evaluation;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.util.Util;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;


/**
 * Aggregates metrics from a list of evaluations, or a list of models and datasets.
 */
public final class EvaluationAggregator {
    // Utility class; the private constructor prevents instantiation.
    private EvaluationAggregator() {}

    /**
     * Summarize performance w.r.t. metric across several models on a single dataset.
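     * <p>A minimal usage sketch, assuming {@code metric}, {@code modelA},
     * {@code modelB} and {@code testSet} already exist for a suitable output type:</p>
     * <pre>{@code
     * // One metric value is computed per model, then pooled into summary statistics.
     * DescriptiveStats stats =
     *     EvaluationAggregator.summarize(metric, Arrays.asList(modelA, modelB), testSet);
     * System.out.println(stats);
     * }</pre>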
     * @param metric The metric to summarize.
     * @param models The models to evaluate.
     * @param dataset The dataset to evaluate.
     * @param <T> The output type.
     * @param <C> The context type used for this metric.
     * @return The descriptive statistics for this metric summary.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> DescriptiveStats summarize(EvaluationMetric<T,C> metric, List<? extends Model<T>> models, Dataset<T> dataset) {
        DescriptiveStats summary = new DescriptiveStats();
        for (Model<T> model : models) {
            C ctx = metric.createContext(model, dataset);
            double value = metric.compute(ctx);
            summary.addValue(value);
        }
        return summary;
    }

    /**
     * Summarize performance using the supplied evaluator across several models on one dataset.
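     * <p>For example, using the classification package's evaluator (a sketch; the
     * trained {@code models} list and {@code testSet} are assumed to exist):</p>
     * <pre>{@code
     * Map<MetricID<Label>, DescriptiveStats> summaries =
     *     EvaluationAggregator.summarize(new LabelEvaluator(), models, testSet);
     * // One DescriptiveStats per metric, aggregated over all the models.
     * }</pre>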
     * @param evaluator The evaluator to use.
     * @param models The models to evaluate.
     * @param dataset The dataset to evaluate.
     * @param <T> The output type.
     * @param <R> The evaluation type.
     * @return Descriptive statistics for each metric in the evaluator.
     */
    public static <T extends Output<T>,
            R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(Evaluator<T,R> evaluator, List<? extends Model<T>> models, Dataset<T> dataset) {
        List<R> evals = models.stream().map(model -> evaluator.evaluate(model, dataset)).collect(Collectors.toList());
        return summarize(evals);
    }

    /**
     * Summarize a model's performance w.r.t. a metric across several datasets.
     *
     * @param metric The metric to evaluate.
     * @param model The model to evaluate.
     * @param datasets The datasets to evaluate.
     * @param <T> The output type.
     * @param <C> The metric context type.
     * @return Descriptive statistics for the metric across the datasets.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> DescriptiveStats summarize(EvaluationMetric<T,C> metric, Model<T> model, List<? extends Dataset<T>> datasets) {
        DescriptiveStats summary = new DescriptiveStats();
        for (Dataset<T> dataset : datasets) {
            C ctx = metric.createContext(model, dataset);
            double value = metric.compute(ctx);
            summary.addValue(value);
        }
        return summary;
    }

    /**
     * Summarize model performance on a dataset across several metrics.
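     * <p>A sketch, assuming {@code metricA} and {@code metricB} share a context type
     * and that {@code model} and {@code dataset} already exist:</p>
     * <pre>{@code
     * // The model's predictions are computed once and shared by every metric.
     * DescriptiveStats stats =
     *     EvaluationAggregator.summarize(Arrays.asList(metricA, metricB), model, dataset);
     * }</pre>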
     * @param metrics The metrics to evaluate.
     * @param model The model to evaluate them on.
     * @param dataset The dataset to evaluate them on.
     * @param <T> The output type.
     * @param <C> The metric context type.
     * @return The descriptive statistics for the metrics.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> DescriptiveStats summarize(List<? extends EvaluationMetric<T,C>> metrics, Model<T> model, Dataset<T> dataset) {
        List<Prediction<T>> predictions = model.predict(dataset);
        DescriptiveStats summary = new DescriptiveStats();
        for (EvaluationMetric<T,C> metric : metrics) {
            C ctx = metric.createContext(model, predictions);
            double value = metric.compute(ctx);
            summary.addValue(value);
        }
        return summary;
    }

    /**
     * Summarize model performance on a set of predictions across several metrics.
     * @param metrics The metrics to evaluate.
     * @param model The model which produced the predictions.
     * @param predictions The predictions to evaluate.
     * @param <T> The output type.
     * @param <C> The metric context type.
     * @return The descriptive statistics for the metrics.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> DescriptiveStats summarize(List<? extends EvaluationMetric<T,C>> metrics, Model<T> model, List<Prediction<T>> predictions) {
        DescriptiveStats summary = new DescriptiveStats();
        for (EvaluationMetric<T,C> metric : metrics) {
            C ctx = metric.createContext(model, predictions);
            double value = metric.compute(ctx);
            summary.addValue(value);
        }
        return summary;
    }

    /**
     * Summarize performance according to the evaluator for a single model across several datasets.
     * @param evaluator The evaluator to use.
     * @param model The model to evaluate.
     * @param datasets The datasets to evaluate across.
     * @param <T> The output type.
     * @param <R> The evaluation type.
     * @return The descriptive statistics for each metric.
     */
    public static <T extends Output<T>,
            R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(Evaluator<T,R> evaluator, Model<T> model, List<? extends Dataset<T>> datasets) {
        List<R> evals = datasets.stream().map(data -> evaluator.evaluate(model, data)).collect(Collectors.toList());
        return summarize(evals);
    }

    /**
     * Summarize all fields of a list of evaluations.
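     * <p>This is useful for pooling repeated runs such as cross-validation folds
     * (a sketch for a classification task; {@code evaluations} is assumed to hold
     * one {@code LabelEvaluation} per fold):</p>
     * <pre>{@code
     * Map<MetricID<Label>, DescriptiveStats> summary =
     *     EvaluationAggregator.summarize(evaluations);
     * summary.forEach((id, stats) -> System.out.println(id + " -> " + stats));
     * }</pre>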
     * @param evaluations The evaluations to summarize.
     * @param <T> The output type.
     * @param <R> The evaluation type.
     * @return The descriptive statistics for each metric.
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(List<R> evaluations) {
        Map<MetricID<T>, DescriptiveStats> results = new HashMap<>();
        for (R evaluation : evaluations) {
            for (Map.Entry<MetricID<T>, Double> kv : evaluation.asMap().entrySet()) {
                MetricID<T> key = kv.getKey();
                DescriptiveStats summary = results.getOrDefault(key, new DescriptiveStats());
                summary.addValue(kv.getValue());
                results.put(key, summary);
            }
        }
        return results;
    }

    /**
     * Summarize a single field of an evaluation across several evaluations.
     *
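     * <p>A sketch using a method reference as the getter, assuming {@code evaluations}
     * is a {@code List<LabelEvaluation>}:</p>
     * <pre>{@code
     * DescriptiveStats accuracyStats =
     *     EvaluationAggregator.summarize(evaluations, LabelEvaluation::accuracy);
     * }</pre>
     *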
     * @param evaluations The evaluations to summarize.
     * @param fieldGetter The getter for the field to summarize.
     * @param <T> The output type.
     * @param <R> The evaluation type.
     * @return The descriptive statistics for the extracted field.
     */
    public static <T extends Output<T>, R extends Evaluation<T>> DescriptiveStats summarize(List<R> evaluations, ToDoubleFunction<R> fieldGetter) {
        DescriptiveStats summary = new DescriptiveStats();
        for (R evaluation : evaluations) {
            double value = fieldGetter.applyAsDouble(evaluation);
            summary.addValue(value);
        }
        return summary;
    }

    /**
     * Calculates the argmax of a metric across the supplied models (i.e., the index of the model that performed best).
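     * <p>A model-selection sketch, assuming a classification setting where
     * {@code metric}, {@code models} and {@code testSet} already exist:</p>
     * <pre>{@code
     * Pair<Integer, Double> best = EvaluationAggregator.argmax(metric, models, testSet);
     * Model<Label> bestModel = models.get(best.getA()); // index of the winning model
     * double bestValue = best.getB();                   // its metric value
     * }</pre>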
     * @param metric The metric to evaluate.
     * @param models The models to evaluate across.
     * @param dataset The dataset to evaluate on.
     * @param <T> The output type.
     * @param <C> The metric context.
     * @return The maximum value and its index in the models list.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> Pair<Integer, Double> argmax(EvaluationMetric<T,C> metric, List<? extends Model<T>> models, Dataset<T> dataset) {
        List<Double> values = models.stream()
                .map(model -> metric.compute(metric.createContext(model, dataset)))
                .collect(Collectors.toList());
        return Util.argmax(values);
    }

    /**
     * Calculates the argmax of a metric across the supplied datasets.
     * @param metric The metric to evaluate.
     * @param model The model to evaluate on.
     * @param datasets The datasets to evaluate across.
     * @param <T> The output type.
     * @param <C> The metric context.
     * @return The maximum value and its index in the datasets list.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> Pair<Integer, Double> argmax(EvaluationMetric<T,C> metric, Model<T> model, List<? extends Dataset<T>> datasets) {
        List<Double> values = datasets.stream()
                .map(dataset -> metric.compute(metric.createContext(model, dataset)))
                .collect(Collectors.toList());
        return Util.argmax(values);
    }

    /**
     * Calculates the argmax of a metric across the supplied evaluations.
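     * <p>A sketch picking the best of several precomputed evaluations, assuming
     * {@code evaluations} is a {@code List<LabelEvaluation>}:</p>
     * <pre>{@code
     * Pair<Integer, Double> best =
     *     EvaluationAggregator.argmax(evaluations, LabelEvaluation::accuracy);
     * }</pre>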
     * @param evaluations The evaluations.
     * @param getter The function to extract a value from the evaluation.
     * @param <T> The output type.
     * @param <R> The evaluation type.
     * @return The maximum value and its index in the evaluations list.
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Pair<Integer, Double> argmax(List<R> evaluations, Function<R, Double> getter) {
        List<Double> values = evaluations.stream().map(getter).collect(Collectors.toList());
        return Util.argmax(values);
    }

}