/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.evaluation;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.util.Util;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;


/**
 * Aggregates metrics from a list of evaluations, or a list of models and datasets.
 */
public final class EvaluationAggregator {

    // Private constructor to prevent instantiation of this utility class.
    private EvaluationAggregator() {}

    /**
     * Summarize performance w.r.t. a metric across several models on a single dataset.
     * @param metric The metric to summarize.
     * @param models The models to evaluate.
     * @param dataset The dataset to evaluate on.
     * @param <T> The output type.
     * @param <C> The context type used for this metric.
     * @return The descriptive statistics for this metric summary.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> DescriptiveStats summarize(EvaluationMetric<T,C> metric, List<? extends Model<T>> models, Dataset<T> dataset) {
        DescriptiveStats summary = new DescriptiveStats();
        for (Model<T> model : models) {
            C ctx = metric.createContext(model, dataset);
            double value = metric.compute(ctx);
            summary.addValue(value);
        }
        return summary;
    }

    /**
     * Summarize performance using the supplied evaluator across several models on one dataset.
     * @param evaluator The evaluator to use.
     * @param models The models to evaluate.
     * @param dataset The dataset to evaluate on.
     * @param <T> The output type.
     * @param <R> The evaluation type.
     * @return Descriptive statistics for each metric in the evaluator.
     */
    public static <T extends Output<T>,
            R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(Evaluator<T,R> evaluator, List<? extends Model<T>> models, Dataset<T> dataset) {
        List<R> evals = models.stream().map(model -> evaluator.evaluate(model, dataset)).collect(Collectors.toList());
        return summarize(evals);
    }

    /**
     * Summarize a model's performance w.r.t. a metric across several datasets.
     *
     * @param metric The metric to evaluate.
     * @param model The model to evaluate.
     * @param datasets The datasets to evaluate on.
     * @param <T> The output type.
     * @param <C> The metric context type.
     * @return Descriptive statistics for the metric across the datasets.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> DescriptiveStats summarize(EvaluationMetric<T,C> metric, Model<T> model, List<? extends Dataset<T>> datasets) {
        DescriptiveStats summary = new DescriptiveStats();
        for (Dataset<T> dataset : datasets) {
            C ctx = metric.createContext(model, dataset);
            double value = metric.compute(ctx);
            summary.addValue(value);
        }
        return summary;
    }

    /**
     * Summarize model performance on a dataset across several metrics.
     * @param metrics The metrics to evaluate.
     * @param model The model to evaluate them on.
     * @param dataset The dataset to evaluate them on.
     * @param <T> The output type.
     * @param <C> The metric context type.
     * @return The descriptive statistics for the metrics.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> DescriptiveStats summarize(List<? extends EvaluationMetric<T,C>> metrics, Model<T> model, Dataset<T> dataset) {
        List<Prediction<T>> predictions = model.predict(dataset);
        DescriptiveStats summary = new DescriptiveStats();
        for (EvaluationMetric<T,C> metric : metrics) {
            C ctx = metric.createContext(model, predictions);
            double value = metric.compute(ctx);
            summary.addValue(value);
        }
        return summary;
    }

    /**
     * Summarize model performance on a set of predictions across several metrics.
     * @param metrics The metrics to evaluate.
     * @param model The model which produced the predictions.
     * @param predictions The predictions to evaluate.
     * @param <T> The output type.
     * @param <C> The metric context type.
     * @return The descriptive statistics for the metrics.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> DescriptiveStats summarize(List<? extends EvaluationMetric<T,C>> metrics, Model<T> model, List<Prediction<T>> predictions) {
        DescriptiveStats summary = new DescriptiveStats();
        for (EvaluationMetric<T,C> metric : metrics) {
            C ctx = metric.createContext(model, predictions);
            double value = metric.compute(ctx);
            summary.addValue(value);
        }
        return summary;
    }

    /**
     * Summarize performance according to the supplied evaluator for a single model across several datasets.
     * @param evaluator The evaluator to use.
     * @param model The model to evaluate.
     * @param datasets The datasets to evaluate across.
     * @param <T> The output type.
     * @param <R> The evaluation type.
     * @return The descriptive statistics for each metric.
     */
    public static <T extends Output<T>,
            R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(Evaluator<T,R> evaluator, Model<T> model, List<? extends Dataset<T>> datasets) {
        List<R> evals = datasets.stream().map(data -> evaluator.evaluate(model, data)).collect(Collectors.toList());
        return summarize(evals);
    }

    /**
     * Summarize all fields of a list of evaluations.
     * @param evaluations The evaluations to summarize.
     * @param <T> The output type.
     * @param <R> The evaluation type.
     * @return The descriptive statistics for each metric.
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(List<R> evaluations) {
        Map<MetricID<T>, DescriptiveStats> results = new HashMap<>();
        for (R evaluation : evaluations) {
            for (Map.Entry<MetricID<T>, Double> kv : evaluation.asMap().entrySet()) {
                MetricID<T> key = kv.getKey();
                DescriptiveStats summary = results.getOrDefault(key, new DescriptiveStats());
                summary.addValue(kv.getValue());
                results.put(key, summary);
            }
        }
        return results;
    }

    /**
     * Summarize a single field of an evaluation across several evaluations.
     *
     * @param evaluations The evaluations.
     * @param fieldGetter The getter for the field to summarize.
     * @param <T> The output type.
     * @param <R> The evaluation type.
     * @return A descriptive statistics summary of the field.
     */
    public static <T extends Output<T>, R extends Evaluation<T>> DescriptiveStats summarize(List<R> evaluations, ToDoubleFunction<R> fieldGetter) {
        DescriptiveStats summary = new DescriptiveStats();
        for (R evaluation : evaluations) {
            double value = fieldGetter.applyAsDouble(evaluation);
            summary.addValue(value);
        }
        return summary;
    }

    /**
     * Calculates the argmax of a metric across the supplied models (i.e., the index of the model which performed the best).
     * @param metric The metric to evaluate.
     * @param models The models to evaluate across.
     * @param dataset The dataset to evaluate on.
     * @param <T> The output type.
     * @param <C> The metric context.
     * @return The maximum value and its index in the models list.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> Pair<Integer, Double> argmax(EvaluationMetric<T,C> metric, List<? extends Model<T>> models, Dataset<T> dataset) {
        List<Double> values = models.stream()
                .map(model -> metric.compute(metric.createContext(model, dataset)))
                .collect(Collectors.toList());
        return Util.argmax(values);
    }

    /**
     * Calculates the argmax of a metric across the supplied datasets.
     * @param metric The metric to evaluate.
     * @param model The model to evaluate on.
     * @param datasets The datasets to evaluate across.
     * @param <T> The output type.
     * @param <C> The metric context.
     * @return The maximum value and its index in the datasets list.
     */
    public static <T extends Output<T>,
            C extends MetricContext<T>> Pair<Integer, Double> argmax(EvaluationMetric<T,C> metric, Model<T> model, List<? extends Dataset<T>> datasets) {
        List<Double> values = datasets.stream()
                .map(dataset -> metric.compute(metric.createContext(model, dataset)))
                .collect(Collectors.toList());
        return Util.argmax(values);
    }

    /**
     * Calculates the argmax of a metric across the supplied evaluations.
     * @param evaluations The evaluations.
     * @param getter The function to extract a value from the evaluation.
     * @param <T> The output type.
     * @param <R> The evaluation type.
     * @return The maximum value and its index in the evaluations list.
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Pair<Integer, Double> argmax(List<R> evaluations, Function<R, Double> getter) {
        List<Double> values = evaluations.stream().map(getter).collect(Collectors.toList());
        return Util.argmax(values);
    }

}
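
/*
 * Usage sketch: a hedged example of summarizing an evaluator's metrics for one
 * model across several test datasets, then picking the dataset on which an
 * accuracy getter is highest. The identifiers labelEvaluator, model, testSets,
 * and the Label output type (from org.tribuo.classification) are illustrative
 * placeholders, not part of this class.
 *
 *     Map<MetricID<Label>, DescriptiveStats> stats =
 *         EvaluationAggregator.summarize(labelEvaluator, model, testSets);
 *     stats.forEach((id, summary) -> System.out.println(id + " -> " + summary));
 *
 *     List<LabelEvaluation> evals = testSets.stream()
 *         .map(data -> labelEvaluator.evaluate(model, data))
 *         .collect(Collectors.toList());
 *     Pair<Integer, Double> best = EvaluationAggregator.argmax(evals, LabelEvaluation::accuracy);
 *     System.out.println("Best dataset index = " + best.getA() + ", accuracy = " + best.getB());
 */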