001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo;
018
019import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
020import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
021import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
022import com.oracle.labs.mlrg.olcut.util.Pair;
023import org.tribuo.provenance.ModelProvenance;
024
025import java.io.Serializable;
026import java.util.ArrayList;
027import java.util.List;
028import java.util.Map;
029import java.util.Optional;
030import java.util.Set;
031
032/**
033 * A prediction model, which is used to predict outputs for unseen instances.
034 * Model implementations must be serializable!
035 * <p>
036 * If two features map to the same id in the featureIDMap, then
037 * occurrences of those features will be combined at prediction time.
038 * <p>
039 * @param <T> the type of prediction produced by the model.
040 */
041public abstract class Model<T extends Output<T>> implements Provenancable<ModelProvenance>, Serializable {
042    private static final long serialVersionUID = 2L;
043
044    /**
045     * Used in getTopFeatures when the Model doesn't support per output feature lists.
046     */
047    public static final String ALL_OUTPUTS = "ALL_OUTPUTS";
048
049    /**
050     * Used to denote the bias feature in a linear model.
051     */
052    public static final String BIAS_FEATURE = "BIAS";
053
054    /**
055     * The model's name.
056     */
057    protected String name;
058
059    /**
060     * The model provenance.
061     */
062    protected final ModelProvenance provenance;
063
064    /**
065     * The cached toString of the model provenance.
066     * <p>
067     * Mostly cached so it appears in the serialized output and can be read by grepping the binary.
068     */
069    protected final String provenanceOutput;
070
071    /**
072     * The features this model knows about.
073     */
074    protected final ImmutableFeatureMap featureIDMap;
075
076    /**
077     * The outputs this model predicts.
078     */
079    protected final ImmutableOutputInfo<T> outputIDInfo;
080
081    /**
082     * Does this model generate probability distributions in the output.
083     */
084    protected final boolean generatesProbabilities;
085
086    /**
087     * Constructs a new model, storing the supplied fields.
088     * @param name The model name.
089     * @param provenance The model provenance.
090     * @param featureIDMap The features.
091     * @param outputIDInfo The possible outputs.
092     * @param generatesProbabilities Does this model emit probabilistic outputs.
093     */
094    protected Model(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities) {
095        this.name = name;
096        this.provenance = provenance;
097        this.provenanceOutput = provenance.toString();
098        this.featureIDMap = featureIDMap;
099        this.outputIDInfo = outputIDInfo;
100        this.generatesProbabilities = generatesProbabilities;
101    }
102
103    /**
104     * Returns the model name.
105     * @return The model name.
106     */
107    public String getName() {
108        return name;
109    }
110
111    /**
112     * Sets the model name.
113     * @param name The new model name.
114     */
115    public void setName(String name) {
116        this.name = name;
117    }
118 
119    @Override
120    public ModelProvenance getProvenance() {
121        return provenance;
122    }
123
124    /**
125     * Gets the feature domain.
126     * @return The feature domain.
127     */
128    public ImmutableFeatureMap getFeatureIDMap() {
129        return featureIDMap;
130    }
131
132    /**
133     * Gets the output domain.
134     * @return The output domain.
135     */
136    public ImmutableOutputInfo<T> getOutputIDInfo() {
137        return outputIDInfo;
138    }
139
140    /**
141     * Does this model generate probabilistic predictions.
142     * @return True if the model generates probabilistic predictions.
143     */
144    public boolean generatesProbabilities() {
145        return generatesProbabilities;
146    }
147
148    /**
149     * Validates that this Model does in fact support the supplied output type.
150     * <p>
151     * As the output type is erased at runtime, deserialising a Model is an unchecked
152     * operation. This method allows the user to check that the deserialised model is
153     * of the appropriate type, rather than seeing if {@link Model#predict} throws a {@link ClassCastException}
154     * when called.
155     * </p>
156     * @param clazz The class object to verify the output type against.
157     * @return True if the output type is assignable to the class object type, false otherwise.
158     */
159    public boolean validate(Class<? extends Output<?>> clazz) {
160        Set<T> domain = outputIDInfo.getDomain();
161        boolean output = true;
162        for (T type : domain) {
163            output &= clazz.isInstance(type);
164        }
165        return output;
166    }
167
168    /**
169     * Uses the model to predict the output for a single example.
170     * <p>
171     * predict does not mutate the example.
172     * <p>
173     * Throws {@link IllegalArgumentException} if the example has no features
174     * or no feature overlap with the model.
175     * @param example the example to predict.
176     * @return the result of the prediction.
177     */
178    public abstract Prediction<T> predict(Example<T> example);
179    
180    /**
181     * Uses the model to predict the output for multiple examples.
182     * <p>
183     * Throws {@link IllegalArgumentException} if the examples have no features
184     * or no feature overlap with the model.
185     * @param examples the examples to predict.
186     * @return the results of the prediction, in the same order as the 
187     * examples.
188     */
189    public List<Prediction<T>> predict(Iterable<Example<T>> examples) {
190        return innerPredict(examples);
191    }
192
193    /**
194     * Uses the model to predict the outputs for multiple examples contained in
195     * a data set.
196     * <p>
197     * Throws {@link IllegalArgumentException} if the examples have no features
198     * or no feature overlap with the model.
199     * @param examples the data set containing the examples to predict.
200     * @return the results of the predictions, in the same order as the 
201     * Dataset provides the examples.
202     */
203    public List<Prediction<T>> predict(Dataset<T> examples) {
204        return innerPredict(examples);
205    }
206
207    /**
208     * Called by the base implementations of {@link Model#predict(Iterable)} and {@link Model#predict(Dataset)}.
209     * @param examples The examples to predict.
210     * @return The results of the predictions, in the same order as the examples.
211     */
212    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
213        List<Prediction<T>> predictions = new ArrayList<>();
214        for (Example<T> example : examples) {
215            predictions.add(predict(example));
216        }
217        return predictions;
218    }
219    
220    /**
221     * Gets the top {@code n} features associated with this model.
222     * <p>
223     * If the model does not produce per output feature lists, it returns
224     * a map with a single element with key Model.ALL_OUTPUTS.
225     * </p>
226     * <p>
227     * If the model cannot describe it's top features then it returns {@link java.util.Collections#emptyMap}.
228     * </p>
229     * @param n the number of features to return. If this value is less than 0, 
230     * all features should be returned for each class, unless the model cannot score it's features.
231     * @return a map from string outputs to an ordered list of pairs of
232     * feature names and weights associated with that feature in the model
233     */
234    public abstract Map<String,List<Pair<String,Double>>> getTopFeatures(int n);
235
236    /**
237     * Generates an excuse for an example.
238     * <p>
239     * This attempts to explain a classification result. Generating an excuse may be quite an expensive operation.
240     * <p>
241     * This excuse either contains per class information or an entry with key Model.ALL_OUTPUTS.
242     * <p>
243     * The optional is empty if the model does not provide excuses.
244     * @param example The input example.
245     * @return An optional excuse object. The optional is empty if this model does not provide excuses.
246     */
247    public abstract Optional<Excuse<T>> getExcuse(Example<T> example);
248
249    /**
250     * Generates an excuse for each example.
251     * <p>
252     * This may be an expensive operation, and probably should be overridden in subclasses for performance reasons.
253     * <p>
254     * These excuses either contain per class information or an entry with key Model.ALL_OUTPUTS.
255     * <p>
256     * The optional is empty if the model does not provide excuses.
257     * @param examples An iterable of examples
258     * @return A optional list of excuses. The Optional is empty if this model does not provide excuses.
259     */
260    public Optional<List<Excuse<T>>> getExcuses(Iterable<Example<T>> examples) {
261        List<Excuse<T>> excuses = new ArrayList<>();
262        for (Example<T> e : examples) {
263            Optional<Excuse<T>> excuse = getExcuse(e);
264            if (excuse.isPresent()) {
265                excuses.add(excuse.get());
266            } else {
267                return Optional.empty();
268            }
269        }
270        return Optional.of(excuses);
271    }
272
273    /**
274     * Copies a model, returning a deep copy of any mutable state, and a shallow copy otherwise.
275     * @return A copy of the model.
276     */
277    public Model<T> copy() {
278        List<ObjectMarshalledProvenance> omp = ProvenanceUtil.marshalProvenance(provenance);
279        ModelProvenance provenanceCopy = (ModelProvenance) ProvenanceUtil.unmarshalProvenance(omp);
280        return copy(name,provenanceCopy);
281    }
282
283    /**
284     * Copies a model, replacing it's provenance and name with the supplied values.
285     * <p>
286     * Used to provide the provenance removal functionality.
287     * @param newName The new name.
288     * @param newProvenance The new provenance.
289     * @return A copy of the model.
290     */
291    protected abstract Model<T> copy(String newName, ModelProvenance newProvenance);
292
293    @Override
294    public String toString() {
295        if (name != null && !name.isEmpty()) {
296            return name + " - " + provenanceOutput;
297        } else {
298            return provenanceOutput;
299        }
300    }
301    
302}