/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo;

import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.provenance.ModelProvenance;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * A prediction model, which is used to predict outputs for unseen instances.
 * Model implementations must be serializable!
 * <p>
 * If two features map to the same id in the featureIDMap, then
 * occurrences of those features will be combined at prediction time.
 *
 * @param <T> the type of prediction produced by the model.
 */
public abstract class Model<T extends Output<T>> implements Provenancable<ModelProvenance>, Serializable {
    private static final long serialVersionUID = 2L;

    /**
     * Used in getTopFeatures when the Model doesn't support per output feature lists.
     */
    public static final String ALL_OUTPUTS = "ALL_OUTPUTS";

    /**
     * Used to denote the bias feature in a linear model.
     */
    public static final String BIAS_FEATURE = "BIAS";

    /**
     * The model's name.
     */
    protected String name;

    /**
     * The model provenance.
     */
    protected final ModelProvenance provenance;

    /**
     * The cached toString of the model provenance.
     * <p>
     * Mostly cached so it appears in the serialized output and can be read by grepping the binary.
     */
    protected final String provenanceOutput;

    /**
     * The features this model knows about.
     */
    protected final ImmutableFeatureMap featureIDMap;

    /**
     * The outputs this model predicts.
     */
    protected final ImmutableOutputInfo<T> outputIDInfo;

    /**
     * Does this model generate probability distributions in the output.
     */
    protected final boolean generatesProbabilities;

    /**
     * Constructs a new model, storing the supplied fields.
     * <p>
     * The provenance's {@code toString()} is evaluated here and cached in
     * {@link #provenanceOutput}, so the supplied provenance must not be null.
     * @param name The model name.
     * @param provenance The model provenance.
     * @param featureIDMap The features.
     * @param outputIDInfo The possible outputs.
     * @param generatesProbabilities Does this model emit probabilistic outputs.
     * @throws NullPointerException if {@code provenance} is null.
     */
    protected Model(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities) {
        this.name = name;
        this.provenance = provenance;
        this.provenanceOutput = provenance.toString();
        this.featureIDMap = featureIDMap;
        this.outputIDInfo = outputIDInfo;
        this.generatesProbabilities = generatesProbabilities;
    }

    /**
     * Returns the model name.
     * @return The model name.
     */
    public String getName() {
        return name;
    }

    /**
     * Sets the model name.
     * @param name The new model name.
     */
    public void setName(String name) {
        this.name = name;
    }

    @Override
    public ModelProvenance getProvenance() {
        return provenance;
    }

    /**
     * Gets the feature domain.
     * @return The feature domain.
     */
    public ImmutableFeatureMap getFeatureIDMap() {
        return featureIDMap;
    }

    /**
     * Gets the output domain.
     * @return The output domain.
     */
    public ImmutableOutputInfo<T> getOutputIDInfo() {
        return outputIDInfo;
    }

    /**
     * Does this model generate probabilistic predictions.
     * @return True if the model generates probabilistic predictions.
     */
    public boolean generatesProbabilities() {
        return generatesProbabilities;
    }

    /**
     * Validates that this Model does in fact support the supplied output type.
     * <p>
     * As the output type is erased at runtime, deserialising a Model is an unchecked
     * operation. This method allows the user to check that the deserialised model is
     * of the appropriate type, rather than seeing if {@link Model#predict} throws a {@link ClassCastException}
     * when called.
     * </p>
     * <p>
     * Every element of the output domain is tested against {@code clazz}; if the
     * domain is empty there is nothing to contradict the type and this returns true.
     * </p>
     * @param clazz The class object to verify the output type against.
     * @return True if the output type is assignable to the class object type, false otherwise.
     */
    public boolean validate(Class<? extends Output<?>> clazz) {
        Set<T> domain = outputIDInfo.getDomain();
        for (T type : domain) {
            // A single mismatch settles the answer, so stop scanning immediately.
            if (!clazz.isInstance(type)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Uses the model to predict the output for a single example.
     * <p>
     * predict does not mutate the example.
     * <p>
     * Throws {@link IllegalArgumentException} if the example has no features
     * or no feature overlap with the model.
     * @param example the example to predict.
     * @return the result of the prediction.
     */
    public abstract Prediction<T> predict(Example<T> example);

    /**
     * Uses the model to predict the output for multiple examples.
     * <p>
     * Throws {@link IllegalArgumentException} if the examples have no features
     * or no feature overlap with the model.
     * @param examples the examples to predict.
     * @return the results of the prediction, in the same order as the
     * examples.
     */
    public List<Prediction<T>> predict(Iterable<Example<T>> examples) {
        return innerPredict(examples);
    }

    /**
     * Uses the model to predict the outputs for multiple examples contained in
     * a data set.
     * <p>
     * Throws {@link IllegalArgumentException} if the examples have no features
     * or no feature overlap with the model.
     * @param examples the data set containing the examples to predict.
     * @return the results of the predictions, in the same order as the
     * Dataset provides the examples.
     */
    public List<Prediction<T>> predict(Dataset<T> examples) {
        return innerPredict(examples);
    }

    /**
     * Called by the base implementations of {@link Model#predict(Iterable)} and {@link Model#predict(Dataset)}.
     * <p>
     * Calls {@link Model#predict(Example)} once per example, in iteration order.
     * @param examples The examples to predict.
     * @return The results of the predictions, in the same order as the examples.
     */
    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
        List<Prediction<T>> predictions = new ArrayList<>();
        for (Example<T> example : examples) {
            predictions.add(predict(example));
        }
        return predictions;
    }

    /**
     * Gets the top {@code n} features associated with this model.
     * <p>
     * If the model does not produce per output feature lists, it returns
     * a map with a single element with key Model.ALL_OUTPUTS.
     * </p>
     * <p>
     * If the model cannot describe its top features then it returns {@link java.util.Collections#emptyMap}.
     * </p>
     * @param n the number of features to return. If this value is less than 0,
     * all features should be returned for each class, unless the model cannot score its features.
     * @return a map from string outputs to an ordered list of pairs of
     * feature names and weights associated with that feature in the model
     */
    public abstract Map<String,List<Pair<String,Double>>> getTopFeatures(int n);

    /**
     * Generates an excuse for an example.
     * <p>
     * This attempts to explain a classification result. Generating an excuse may be quite an expensive operation.
     * <p>
     * This excuse either contains per class information or an entry with key Model.ALL_OUTPUTS.
     * <p>
     * The optional is empty if the model does not provide excuses.
     * @param example The input example.
     * @return An optional excuse object. The optional is empty if this model does not provide excuses.
     */
    public abstract Optional<Excuse<T>> getExcuse(Example<T> example);

    /**
     * Generates an excuse for each example.
     * <p>
     * This may be an expensive operation, and probably should be overridden in subclasses for performance reasons.
     * <p>
     * These excuses either contain per class information or an entry with key Model.ALL_OUTPUTS.
     * <p>
     * The optional is empty if the model does not provide excuses. This base implementation
     * calls {@link Model#getExcuse(Example)} for each example in turn and returns an empty
     * optional as soon as any single call yields no excuse.
     * @param examples An iterable of examples
     * @return An optional list of excuses. The Optional is empty if this model does not provide excuses.
     */
    public Optional<List<Excuse<T>>> getExcuses(Iterable<Example<T>> examples) {
        List<Excuse<T>> excuses = new ArrayList<>();
        for (Example<T> e : examples) {
            Optional<Excuse<T>> excuse = getExcuse(e);
            if (excuse.isPresent()) {
                excuses.add(excuse.get());
            } else {
                // One missing excuse makes the whole result empty; partial lists are never returned.
                return Optional.empty();
            }
        }
        return Optional.of(excuses);
    }

    /**
     * Copies a model, returning a deep copy of any mutable state, and a shallow copy otherwise.
     * @return A copy of the model.
     */
    public Model<T> copy() {
        // Round-trip the provenance through its marshalled form, then hand the rebuilt
        // provenance and the current name to the subclass-specific copy method.
        List<ObjectMarshalledProvenance> omp = ProvenanceUtil.marshalProvenance(provenance);
        ModelProvenance provenanceCopy = (ModelProvenance) ProvenanceUtil.unmarshalProvenance(omp);
        return copy(name,provenanceCopy);
    }

    /**
     * Copies a model, replacing its provenance and name with the supplied values.
     * <p>
     * Used to provide the provenance removal functionality.
     * @param newName The new name.
     * @param newProvenance The new provenance.
     * @return A copy of the model.
     */
    protected abstract Model<T> copy(String newName, ModelProvenance newProvenance);

    /**
     * Returns the cached provenance string, prefixed with the model name and
     * {@code " - "} when the name is neither null nor empty.
     * @return A string describing this model.
     */
    @Override
    public String toString() {
        if (name != null && !name.isEmpty()) {
            return name + " - " + provenanceOutput;
        } else {
            return provenanceOutput;
        }
    }

}