001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.evaluation;
018
019import org.tribuo.Example;
020import org.tribuo.Model;
021import org.tribuo.Output;
022import org.tribuo.Prediction;
023import org.tribuo.provenance.DataProvenance;
024
025import java.util.ArrayList;
026import java.util.List;
027
028/**
029 * An evaluator which aggregates predictions and produces {@link Evaluation}s
030 * covering all the {@link Prediction}s it has seen or created.
031 * @param <T> The output type.
032 * @param <E> The evaluation type.
033 */
034public final class OnlineEvaluator<T extends Output<T>, E extends Evaluation<T>> {
035
036    private final Evaluator<T,E> evaluator;
037    private final Model<T> model;
038    private final DataProvenance provenance;
039
040    private final List<Prediction<T>> predictions = new ArrayList<>();
041
042    /**
043     * Constructs an {@code OnlineEvaluator} which accumulates predictions.
044     * @param evaluator The evaluator to use to make {@link Evaluation}s.
045     * @param model The model to use.
046     * @param provenance The provenance of the evaluation data.
047     */
048    public OnlineEvaluator(Evaluator<T,E> evaluator, Model<T> model, DataProvenance provenance) {
049        this.evaluator = evaluator;
050        this.model = model;
051        this.provenance = provenance;
052    }
053
054    /**
055     * Creates an {@link Evaluation} containing all the current
056     * predictions.
057     * @return An {@link Evaluation} of the appropriate type.
058     */
059    public E evaluate() {
060        return evaluator.evaluate(model,new ArrayList<>(predictions),provenance);
061    }
062
063    /**
064     * Feeds the example to the model, records the prediction and returns it.
065     * @param example The example to predict.
066     * @return The model prediction for this example.
067     */
068    public synchronized Prediction<T> predictAndObserve(Example<T> example) {
069        Prediction<T> cur = model.predict(example);
070        predictions.add(cur);
071        return cur;
072    }
073
074    /**
075     * Feeds the examples to the model, records the predictions and returns them.
076     * @param examples The examples to predict.
077     * @return The model predictions for the supplied examples.
078     */
079    public synchronized List<Prediction<T>> predictAndObserve(Iterable<Example<T>> examples) {
080        List<Prediction<T>> cur = model.predict(examples);
081        predictions.addAll(cur);
082        return new ArrayList<>(cur);
083    }
084
085    /**
086     * Records the supplied prediction.
087     * @param newPrediction The prediction to record.
088     */
089    public synchronized void observe(Prediction<T> newPrediction) {
090        predictions.add(newPrediction);
091    }
092
093    /**
094     * Records all the supplied predictions.
095     * @param newPredictions The predictions to record.
096     */
097    public synchronized void observe(List<Prediction<T>> newPredictions) {
098        predictions.addAll(newPredictions);
099    }
100}