/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.evaluation;

import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DataProvenance;

import java.util.ArrayList;
import java.util.List;

/**
 * An evaluation factory which produces immutable {@link Evaluation}s of a given {@link Dataset} using the given {@link Model}.
 * <p>
 * If the dataset contains an unknown output (as generated by {@link org.tribuo.OutputFactory#getUnknownOutput()})
 * or a valid output which is outside of the domain of the {@link Model} then the evaluate methods will
 * throw {@link IllegalArgumentException} with an appropriate message.
 * @param <T> The output type.
 * @param <E> The evaluation type.
 */
public interface Evaluator<T extends Output<T>, E extends Evaluation<T>> {

    /**
     * Evaluates the dataset using the supplied model, returning an immutable {@link Evaluation} of the appropriate type.
     * @param model The model to use.
     * @param dataset The dataset to evaluate.
     * @return An evaluation.
     */
    E evaluate(Model<T> model, Dataset<T> dataset);

    /**
     * Evaluates the dataset using the supplied model, returning an immutable {@link Evaluation} of the appropriate type.
     * @param model The model to use.
     * @param datasource The data to evaluate.
     * @return An evaluation.
     */
    E evaluate(Model<T> model, DataSource<T> datasource);

    /**
     * Evaluates the model performance using the supplied predictions, returning an immutable {@link Evaluation}
     * of the appropriate type.
     * <p>
     * It does not validate that the {@code model} produced the supplied {@code predictions}, or that
     * the {@code dataProvenance} matches the input examples. Supplying arguments which do not meet
     * these invariants will produce an invalid Evaluation.
     * </p>
     * @param model The model to use.
     * @param predictions The predictions to evaluate.
     * @param dataProvenance The provenance of the predicted dataset.
     * @return An evaluation.
     */
    E evaluate(Model<T> model, List<Prediction<T>> predictions, DataProvenance dataProvenance);

    /**
     * Evaluates the model performance using the supplied predictions, returning an immutable {@link Evaluation}
     * of the appropriate type.
     * <p>
     * This method is used when the predictions do not contain the correct ground truth labels (e.g., if they
     * were collected separately from the examples constructed for prediction). First it creates a new set of
     * predictions, containing the same examples paired element-wise with the supplied ground truth outputs.
     * Then it calls {@link Evaluator#evaluate(Model, List, DataProvenance)} with the updated predictions.
     * <p>
     * It does not validate that the {@code model} produced the supplied {@code predictions}, or that
     * the {@code dataProvenance} matches the input examples. Supplying arguments which do not meet
     * these invariants will produce an invalid Evaluation.
     * @param model The model to use.
     * @param predictions The predictions to evaluate.
     * @param groundTruth The ground truth outputs to use, in the same order as {@code predictions}.
     * @param dataProvenance The provenance of the predicted dataset.
     * @return An evaluation.
     * @throws IllegalArgumentException If {@code predictions} and {@code groundTruth} differ in size.
     */
    default E evaluate(Model<T> model, List<Prediction<T>> predictions, List<T> groundTruth, DataProvenance dataProvenance) {
        if (predictions.size() != groundTruth.size()) {
            throw new IllegalArgumentException(
                    "Predictions and ground truth must be the same length, received predictions.size()="
                    +predictions.size()+", groundTruth.size()="+groundTruth.size());
        }
        List<Prediction<T>> newPredictions = new ArrayList<>(predictions.size());

        for (int i = 0; i < predictions.size(); i++) {
            Prediction<T> curPrediction = predictions.get(i);
            Example<T> curExample = curPrediction.getExample();
            // Copy the example, substituting the supplied ground truth as its output while keeping the weight.
            ArrayExample<T> newExample = new ArrayExample<>(groundTruth.get(i), curExample, curExample.getWeight());
            Prediction<T> newPrediction = new Prediction<>(curPrediction,curPrediction.getNumActiveFeatures(),newExample);
            newPredictions.add(newPrediction);
        }

        return evaluate(model,newPredictions,dataProvenance);
    }

    /**
     * Creates an online evaluator that maintains a list of all the predictions it has seen and can evaluate
     * them upon request.
     * @param model The model to use for online evaluation.
     * @param provenance The provenance of the data.
     * @return An online evaluator.
     */
    default OnlineEvaluator<T,E> createOnlineEvaluator(Model<T> model, DataProvenance provenance) {
        return new OnlineEvaluator<>(this,model,provenance);
    }
}