/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.evaluation;

import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DataProvenance;

import java.util.ArrayList;
import java.util.List;

/**
 * An evaluation factory which produces immutable {@link Evaluation}s of a given {@link Dataset} using the given {@link Model}.
 * <p>
 * If the dataset contains an unknown output (as generated by {@link org.tribuo.OutputFactory#getUnknownOutput()})
 * or a valid output which is outside of the domain of the {@link Model} then the evaluate methods will
 * throw {@link IllegalArgumentException} with an appropriate message.
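 * <p>
 * A typical use is to construct a concrete evaluator and apply it to a trained model and a test dataset.
 * As a sketch (assuming a trained classification {@code Model<Label>} named {@code model}, a test
 * {@code Dataset<Label>} named {@code testData}, and the {@code LabelEvaluator} from the classification package):
 * <pre>{@code
 *     Evaluator<Label,LabelEvaluation> evaluator = new LabelEvaluator();
 *     LabelEvaluation evaluation = evaluator.evaluate(model,testData);
 *     double accuracy = evaluation.accuracy();
 * }</pre>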
 * @param <T> The output type.
 * @param <E> The evaluation type.
 */
public interface Evaluator<T extends Output<T>, E extends Evaluation<T>> {

    /**
     * Evaluates the dataset using the supplied model, returning an immutable {@link Evaluation} of the appropriate type.
     * @param model The model to use.
     * @param dataset The dataset to evaluate.
     * @return An evaluation.
     */
    public E evaluate(Model<T> model, Dataset<T> dataset);

    /**
     * Evaluates the dataset using the supplied model, returning an immutable {@link Evaluation} of the appropriate type.
     * @param model The model to use.
     * @param datasource The data to evaluate.
     * @return An evaluation.
     */
    public E evaluate(Model<T> model, DataSource<T> datasource);

    /**
     * Evaluates the model performance using the supplied predictions, returning an immutable {@link Evaluation}
     * of the appropriate type.
     * <p>
     * It does not validate that the {@code model} produced the supplied {@code predictions}, or that
     * the {@code dataProvenance} matches the input examples. Supplying arguments which do not meet
     * these invariants will produce an invalid Evaluation.
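     * <p>
     * For example, predictions can be generated up front and evaluated later (a sketch; it assumes the
     * {@code predictions} were produced by this {@code model} from a test dataset {@code testData}):
     * <pre>{@code
     *     List<Prediction<Label>> predictions = model.predict(testData);
     *     LabelEvaluation evaluation = evaluator.evaluate(model,predictions,testData.getProvenance());
     * }</pre>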
     * @param model The model to use.
     * @param predictions The predictions to evaluate.
     * @param dataProvenance The provenance of the predicted dataset.
     * @return An evaluation.
     */
    public E evaluate(Model<T> model, List<Prediction<T>> predictions, DataProvenance dataProvenance);

    /**
     * Evaluates the model performance using the supplied predictions, returning an immutable {@link Evaluation}
     * of the appropriate type.
     * <p>
     * This method is used when the predictions do not contain the correct ground truth outputs (e.g., if they
     * were collected separately from the examples constructed for prediction). First it creates a new set of
     * predictions, containing the same examples with the supplied ground truth outputs substituted in.
     * Then it calls {@link Evaluator#evaluate(Model, List, DataProvenance)} with the updated predictions.
     * <p>
     * It does not validate that the {@code model} produced the supplied {@code predictions}, or that
     * the {@code dataProvenance} matches the input examples. Supplying arguments which do not meet
     * these invariants will produce an invalid Evaluation.
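     * <p>
     * For example (a sketch; {@code unlabelledData} and {@code labels} are hypothetical names for a dataset
     * predicted without ground truth and the separately collected ground truth outputs):
     * <pre>{@code
     *     List<Prediction<Label>> predictions = model.predict(unlabelledData);
     *     LabelEvaluation evaluation = evaluator.evaluate(model,predictions,labels,unlabelledData.getProvenance());
     * }</pre>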
     * @param model The model to use.
     * @param predictions The predictions to evaluate.
     * @param groundTruth The ground truth outputs to use.
     * @param dataProvenance The provenance of the predicted dataset.
     * @return An evaluation.
     */
    public default E evaluate(Model<T> model, List<Prediction<T>> predictions, List<T> groundTruth, DataProvenance dataProvenance) {
        if (predictions.size() != groundTruth.size()) {
            throw new IllegalArgumentException(
                    "Predictions and ground truth must be the same length, received predictions.size()="
                            +predictions.size()+", groundTruth.size()="+groundTruth.size());
        }
        List<Prediction<T>> newPredictions = new ArrayList<>(predictions.size());

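        // Rebuild each prediction around a copy of its example which carries the matching ground truth output.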
        for (int i = 0; i < predictions.size(); i++) {
            Prediction<T> curPrediction = predictions.get(i);
            Example<T> curExample = curPrediction.getExample();
            ArrayExample<T> newExample = new ArrayExample<>(groundTruth.get(i), curExample, curExample.getWeight());
            Prediction<T> newPrediction = new Prediction<>(curPrediction,curPrediction.getNumActiveFeatures(),newExample);
            newPredictions.add(newPrediction);
        }

        return evaluate(model,newPredictions,dataProvenance);
    }

    /**
     * Creates an online evaluator that maintains a list of all the predictions it has seen and can evaluate
     * them upon request.
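     * <p>
     * A sketch of the intended use, via {@link OnlineEvaluator}'s {@code predictAndObserve} and
     * {@code evaluate} methods:
     * <pre>{@code
     *     OnlineEvaluator<Label,LabelEvaluation> online = evaluator.createOnlineEvaluator(model,provenance);
     *     online.predictAndObserve(example);          // predict one example and record the prediction
     *     LabelEvaluation soFar = online.evaluate();  // evaluate everything observed so far
     * }</pre>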
     * @param model The model to use for online evaluation.
     * @param provenance The provenance of the data.
     * @return An online evaluator.
     */
    public default OnlineEvaluator<T,E> createOnlineEvaluator(Model<T> model, DataProvenance provenance) {
        return new OnlineEvaluator<>(this,model,provenance);
    }
}