001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo;
018
019import java.io.Serializable;
020import java.util.Collections;
021import java.util.LinkedHashMap;
022import java.util.Map;
023
024/**
025 * A prediction made by a {@link Model}.
026 * Contains the output, and optionally and a map of scores over the possible outputs.
027 * If hasProbabilities() == true then it has a probability
028 * distribution over outputs otherwise it is unnormalized scores over outputs.
029 * <p>
030 * If possible it also contains the number of features that were used to make a prediction,
031 * and how many features originally existed in the {@link Example}.
032 */
033public class Prediction<T extends Output<T>> implements Serializable {
034    private static final long serialVersionUID = 1L;
035
036    /**
037     * The example which was used to generate this prediction.
038     */
039    private final Example<T> example;
040
041    /**
042     * The output assigned by a classifier.
043     */
044    private final T output;
045
046    /**
047     * Does outputScores contain probabilities or scores?
048     */
049    private final boolean probability;
050
051    /**
052     * How many features were used by the model.
053     */
054    private final int numUsed;
055
056    /**
057     * How many features were set in the example.
058     */
059    private final int exampleSize;
060
061    /**
062     * A map from output name to output object, which contains the score.
063     */
064    private final Map<String,T> outputScores;
065
066    /**
067     * Constructs a prediction from the supplied arguments.
068     * @param output The predicted output (i.e., the one with the maximum score).
069     * @param outputScores The output score distribution.
070     * @param numUsed The number of features used to make the prediction.
071     * @param exampleSize The size of the input example.
072     * @param example The input example.
073     * @param probability Are the scores probabilities?
074     */
075    private Prediction(T output, Map<String,T> outputScores, int numUsed, int exampleSize, Example<T> example, boolean probability) {
076        this.example = example;
077        this.outputScores = outputScores;
078        this.numUsed = numUsed;
079        this.exampleSize = exampleSize;
080        this.output = output;
081        this.probability = probability;
082    }
083
084    /**
085     * Constructs a prediction from the supplied arguments.
086     * @param output The predicted output (i.e., the one with the maximum score).
087     * @param outputScores The output score distribution.
088     * @param numUsed The number of features used to make the prediction.
089     * @param example The input example.
090     * @param probability Are the scores probabilities?
091     */
092    public Prediction(T output, Map<String,T> outputScores, int numUsed, Example<T> example, boolean probability) {
093        this(output,outputScores,numUsed,example.size(),example,probability);
094    }
095
096    /**
097     * Constructs a prediction from the supplied arguments.
098     * @param output The predicted output.
099     * @param numUsed The number of features used to make the prediction.
100     * @param example The input example.
101     */
102    public Prediction(T output, int numUsed, Example<T> example) {
103        this(output,Collections.emptyMap(),numUsed,example.size(),example,false);
104    }
105
106    /**
107     * Constructs a prediction from the supplied arguments.
108     * @param other The prediction to copy.
109     * @param numUsed The number of features used to make the prediction.
110     * @param example The input example.
111     */
112    public Prediction(Prediction<T> other, int numUsed, Example<T> example) {
113        this(other.output,new LinkedHashMap<>(other.outputScores),numUsed,example.size(),example,other.probability);
114    }
115
116    /**
117     * Returns the predicted output.
118     * @return The predicted output.
119     */
120    public T getOutput() {
121        return output;
122    }
123
124    /**
125     * Returns the number of features used in the prediction.
126     * @return The number of features used.
127     */
128    public int getNumActiveFeatures() {
129        return numUsed;
130    }
131
132    /**
133     * Returns the number of features in the example.
134     * @return The number of features in the example.
135     */
136    public int getExampleSize() {
137        return exampleSize;
138    }
139
140    /**
141     * Returns the example itself.
142     * @return The example.
143     */
144    public Example<T> getExample() {
145        return example;
146    }
147
148    /**
149     * Gets the output scores for each output.
150     * <p>
151     * May be an empty map if it did not generate scores.
152     * @return A Map.
153     */
154    public Map<String,T> getOutputScores() {
155        return outputScores;
156    }
157
158    /**
159     * Are the scores probabilities?
160     * @return True if the scores are probabilities.
161     */
162    public boolean hasProbabilities() {
163        return probability;
164    }
165
166    @Override
167    public String toString() {
168        StringBuilder buffer = new StringBuilder();
169
170        buffer.append("Prediction(maxLabel=");
171        buffer.append(output);
172        buffer.append(",outputScores={");
173        for (Map.Entry<String,T> e : outputScores.entrySet()) {
174            buffer.append(e.toString());
175        }
176        buffer.delete(buffer.length()-2,buffer.length());
177        buffer.append("})");
178
179        return buffer.toString();
180    }
181}