001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.evaluation; 018 019import org.tribuo.Example; 020import org.tribuo.Model; 021import org.tribuo.Output; 022import org.tribuo.Prediction; 023import org.tribuo.provenance.DataProvenance; 024 025import java.util.ArrayList; 026import java.util.List; 027 028/** 029 * An evaluator which aggregates predictions and produces {@link Evaluation}s 030 * covering all the {@link Prediction}s it has seen or created. 031 * @param <T> The output type. 032 * @param <E> The evaluation type. 033 */ 034public final class OnlineEvaluator<T extends Output<T>, E extends Evaluation<T>> { 035 036 private final Evaluator<T,E> evaluator; 037 private final Model<T> model; 038 private final DataProvenance provenance; 039 040 private final List<Prediction<T>> predictions = new ArrayList<>(); 041 042 /** 043 * Constructs an {@code OnlineEvaluator} which accumulates predictions. 044 * @param evaluator The evaluator to use to make {@link Evaluation}s. 045 * @param model The model to use. 046 * @param provenance The provenance of the evaluation data. 047 */ 048 public OnlineEvaluator(Evaluator<T,E> evaluator, Model<T> model, DataProvenance provenance) { 049 this.evaluator = evaluator; 050 this.model = model; 051 this.provenance = provenance; 052 } 053 054 /** 055 * Creates an {@link Evaluation} containing all the current 056 * predictions. 057 * @return An {@link Evaluation} of the appropriate type. 058 */ 059 public E evaluate() { 060 return evaluator.evaluate(model,new ArrayList<>(predictions),provenance); 061 } 062 063 /** 064 * Feeds the example to the model, records the prediction and returns it. 065 * @param example The example to predict. 066 * @return The model prediction for this example. 067 */ 068 public synchronized Prediction<T> predictAndObserve(Example<T> example) { 069 Prediction<T> cur = model.predict(example); 070 predictions.add(cur); 071 return cur; 072 } 073 074 /** 075 * Feeds the examples to the model, records the predictions and returns them. 076 * @param examples The examples to predict. 077 * @return The model predictions for the supplied examples. 078 */ 079 public synchronized List<Prediction<T>> predictAndObserve(Iterable<Example<T>> examples) { 080 List<Prediction<T>> cur = model.predict(examples); 081 predictions.addAll(cur); 082 return new ArrayList<>(cur); 083 } 084 085 /** 086 * Records the supplied prediction. 087 * @param newPrediction The prediction to record. 088 */ 089 public synchronized void observe(Prediction<T> newPrediction) { 090 predictions.add(newPrediction); 091 } 092 093 /** 094 * Records all the supplied predictions. 095 * @param newPredictions The predictions to record. 096 */ 097 public synchronized void observe(List<Prediction<T>> newPredictions) { 098 predictions.addAll(newPredictions); 099 } 100}