Package org.tribuo

Class Model<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
Type Parameters:
T - the type of prediction produced by the model.
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable
Direct Known Subclasses:
AbstractSGDModel, ClassifierChainModel, DummyClassifierModel, DummyRegressionModel, EnsembleModel, ExternalModel, HdbscanModel, IndependentMultiLabelModel, KernelSVMModel, KMeansModel, KNNModel, LibLinearModel, LibSVMModel, MultinomialNaiveBayesModel, SkeletalIndependentRegressionModel, SparseModel, TensorFlowModel, TransformedModel, XGBoostModel

public abstract class Model<T extends Output<T>> extends Object implements com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable
A prediction model, which is used to predict outputs for unseen instances. Model implementations must be serializable!

If two features map to the same id in the featureIDMap, then occurrences of those features will be combined at prediction time.
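
A minimal sketch of what "combined" could mean, assuming the colliding values are summed. The sketch also assumes that names absent from the feature map are dropped. `FeatureCollisionSketch`, `toIdVector`, and the plain `Map<String,Integer>` standing in for the featureIDMap are hypothetical illustrations, not Tribuo API.

```java
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical illustration: project named feature values onto ids, summing on collision. */
final class FeatureCollisionSketch {
    private FeatureCollisionSketch() {}

    /**
     * Maps each named feature to its id. Unknown names are skipped, and features that
     * share an id have their values added together (an assumed combination rule).
     */
    static Map<Integer, Double> toIdVector(Map<String, Double> features, Map<String, Integer> featureIds) {
        Map<Integer, Double> vector = new TreeMap<>();
        for (Map.Entry<String, Double> e : features.entrySet()) {
            Integer id = featureIds.get(e.getKey());
            if (id != null) {
                vector.merge(id, e.getValue(), Double::sum);
            }
        }
        return vector;
    }
}
```

With "colour" and "color" both mapped to id 0, values 1.0 and 2.0 for those names end up as a single entry 0 = 3.0.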

  • Field Details

    • ALL_OUTPUTS

      public static final String ALL_OUTPUTS
      Used in getTopFeatures when the Model doesn't support per-output feature lists.
    • BIAS_FEATURE

      public static final String BIAS_FEATURE
      Used to denote the bias feature in a linear model.
    • name

      protected String name
      The model's name.
    • provenance

      protected final ModelProvenance provenance
      The model provenance.
    • provenanceOutput

      protected final String provenanceOutput
      The cached toString of the model provenance.

      Mostly cached so it appears in the serialized output and can be read by grepping the binary.
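
      To illustrate why caching the string helps, this JDK-only sketch shows that a String field written with standard Java serialization appears verbatim in the byte stream. The check works for ASCII text because strings are written in modified UTF-8, so a tool such as grep can find the text. `CachedDescription` and its sample text are hypothetical stand-ins for a model and its provenance.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

/** Hypothetical stand-in for a model that caches a provenance description as a String field. */
final class CachedDescription implements Serializable {
    private static final long serialVersionUID = 1L;

    // Computed eagerly so the text is written into the serialized form.
    final String provenanceOutput;

    CachedDescription(String provenanceOutput) {
        this.provenanceOutput = provenanceOutput;
    }

    /** Serializes the object and reports whether the cached text occurs in the raw bytes. */
    static boolean appearsInSerializedBytes(CachedDescription obj) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        // ISO-8859-1 maps each byte to one char, so a substring search approximates grepping the binary.
        String raw = new String(bytes.toByteArray(), StandardCharsets.ISO_8859_1);
        return raw.contains(obj.provenanceOutput);
    }
}
```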

    • featureIDMap

      protected final ImmutableFeatureMap featureIDMap
      The features this model knows about.
    • outputIDInfo

      protected final ImmutableOutputInfo<T> outputIDInfo
      The outputs this model predicts.
    • generatesProbabilities

      protected final boolean generatesProbabilities
      Whether this model generates probability distributions in its output.
  • Constructor Details

    • Model

      protected Model(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities)
      Constructs a new model, storing the supplied fields.
      Parameters:
      name - The model name.
      provenance - The model provenance.
      featureIDMap - The features.
      outputIDInfo - The possible outputs.
      generatesProbabilities - Whether this model emits probabilistic outputs.
  • Method Details

    • getName

      public String getName()
      Returns the model name.
      Returns:
      The model name.
    • setName

      public void setName(String name)
      Sets the model name.
      Parameters:
      name - The new model name.
    • getProvenance

      public ModelProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>
    • getFeatureIDMap

      public ImmutableFeatureMap getFeatureIDMap()
      Gets the feature domain.
      Returns:
      The feature domain.
    • getOutputIDInfo

      public ImmutableOutputInfo<T> getOutputIDInfo()
      Gets the output domain.
      Returns:
      The output domain.
    • generatesProbabilities

      public boolean generatesProbabilities()
      Returns whether this model generates probabilistic predictions.
      Returns:
      True if the model generates probabilistic predictions.
    • validate

      public boolean validate(Class<? extends Output<?>> clazz)
      Validates that this Model does in fact support the supplied output type.

      As the output type is erased at runtime, deserialising a Model is an unchecked operation. This method lets the user check up front that a deserialised model has the expected output type, rather than waiting to see whether predict(org.tribuo.Example<T>) throws a ClassCastException when called.

      Parameters:
      clazz - The class object to verify the output type against.
      Returns:
      True if the output type is assignable to the class object type, false otherwise.
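
      A JDK-only sketch of why an up-front check is useful. An unchecked cast of an erased generic type succeeds silently, and the ClassCastException only surfaces later, when an element is used. An isInstance test over the stored values catches the mismatch first. `ErasureSketch` and its methods are hypothetical. Applying this element-wise check to a model's output domain is an assumption for illustration, not a description of Tribuo's implementation.

```java
import java.util.Collection;
import java.util.List;

/** Hypothetical illustration of erasure and an up-front type check. */
final class ErasureSketch {
    private ErasureSketch() {}

    /** Up-front check in the spirit of validate: every element must be an instance of clazz. */
    static boolean allInstancesOf(Collection<?> values, Class<?> clazz) {
        for (Object value : values) {
            if (!clazz.isInstance(value)) {
                return false;
            }
        }
        return true;
    }

    /** Succeeds without any runtime check because the type argument is erased. */
    @SuppressWarnings("unchecked")
    static List<Integer> uncheckedAsIntegers(List<?> values) {
        return (List<Integer>) values;
    }
}
```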
    • predict

      public abstract Prediction<T> predict(Example<T> example)
      Uses the model to predict the output for a single example.

      predict does not mutate the example.

      Throws IllegalArgumentException if the example has no features or no feature overlap with the model.

      Parameters:
      example - the example to predict.
      Returns:
      the result of the prediction.
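
      The IllegalArgumentException contract can be sketched as a guard that runs before scoring. This hypothetical helper assumes that the example's feature names and the model's feature domain are both available as sets of strings. `PredictGuard` and `checkOverlap` are illustrative names, not Tribuo API.

```java
import java.util.Set;

/** Hypothetical guard mirroring the documented predict preconditions. */
final class PredictGuard {
    private PredictGuard() {}

    /** Throws IllegalArgumentException if the example is empty or shares no feature with the model. */
    static void checkOverlap(Set<String> exampleFeatures, Set<String> modelFeatures) {
        if (exampleFeatures.isEmpty()) {
            throw new IllegalArgumentException("Example has no features");
        }
        for (String name : exampleFeatures) {
            if (modelFeatures.contains(name)) {
                return; // at least one feature is known to the model
            }
        }
        throw new IllegalArgumentException("No feature overlap between the example and the model");
    }
}
```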
    • predict

      public List<Prediction<T>> predict(Iterable<Example<T>> examples)
      Uses the model to predict the output for multiple examples.

      Throws IllegalArgumentException if the examples have no features or no feature overlap with the model.

      Parameters:
      examples - the examples to predict.
      Returns:
      the results of the prediction, in the same order as the examples.
    • predict

      public List<Prediction<T>> predict(Dataset<T> examples)
      Uses the model to predict the outputs for multiple examples contained in a data set.

      Throws IllegalArgumentException if the examples have no features or no feature overlap with the model.

      Parameters:
      examples - the data set containing the examples to predict.
      Returns:
      the results of the predictions, in the same order as the Dataset provides the examples.
    • innerPredict

      protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples)
      Called by the base implementations of predict(Iterable) and predict(Dataset).
      Parameters:
      examples - The examples to predict.
      Returns:
      The results of the predictions, in the same order as the examples.
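
      A sketch of how a base class can route batch prediction through one protected hook. By default, the hook predicts each example in turn, so results come back in input order. `SketchModel` uses String examples and Integer predictions as hypothetical stand-ins for Example<T> and Prediction<T>. It mirrors the shape of the methods documented here, not Tribuo's source.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical model skeleton showing the innerPredict hook. */
abstract class SketchModel {
    /** Single-example prediction, supplied by subclasses. */
    abstract Integer predict(String example);

    /** Batch entry point; a Dataset overload would delegate to the same hook. */
    final List<Integer> predict(Iterable<String> examples) {
        return innerPredict(examples);
    }

    /** Default: one call per example, appended in iteration order. Subclasses may override to batch. */
    protected List<Integer> innerPredict(Iterable<String> examples) {
        List<Integer> results = new ArrayList<>();
        for (String example : examples) {
            results.add(predict(example));
        }
        return results;
    }
}
```

      Overriding only the hook lets a subclass add vectorised scoring without changing either public batch method.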
    • getTopFeatures

      public abstract Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
      Gets the top n features associated with this model.

      If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.

      If the model cannot describe its top features, it returns Collections.emptyMap().

      Parameters:
      n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
      Returns:
      a map from string outputs to an ordered list of pairs of feature names and the weights associated with those features in the model.
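
      A sketch of top-n selection for a model without per-output lists. It assumes features are ranked by absolute weight, a common choice for linear models that this page does not confirm. It treats n < 0 as "return everything". `TopFeaturesSketch` is hypothetical, and Map.Entry stands in for OLCUT's Pair<String,Double>. The "ALL_OUTPUTS" key string is a placeholder, since this page does not give the value of Model.ALL_OUTPUTS.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/** Hypothetical top-n feature selection over a single weight vector. */
final class TopFeaturesSketch {
    /** Placeholder key; not necessarily the real value of Model.ALL_OUTPUTS. */
    static final String ALL_OUTPUTS_KEY = "ALL_OUTPUTS";

    private TopFeaturesSketch() {}

    static Map<String, List<Map.Entry<String, Double>>> topFeatures(Map<String, Double> weights, int n) {
        List<Map.Entry<String, Double>> ranked = new ArrayList<>();
        for (Map.Entry<String, Double> e : weights.entrySet()) {
            ranked.add(new AbstractMap.SimpleImmutableEntry<>(e.getKey(), e.getValue()));
        }
        // Largest absolute weight first (assumed ranking rule).
        ranked.sort(Comparator.comparingDouble((Map.Entry<String, Double> e) -> Math.abs(e.getValue())).reversed());
        int limit = n < 0 ? ranked.size() : Math.min(n, ranked.size());
        List<Map.Entry<String, Double>> top = new ArrayList<>(ranked.subList(0, limit));
        return Collections.singletonMap(ALL_OUTPUTS_KEY, top);
    }
}
```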
    • getExcuse

      public abstract Optional<Excuse<T>> getExcuse(Example<T> example)
      Generates an excuse for an example.

      This attempts to explain a classification result. Generating an excuse may be quite an expensive operation.

      This excuse either contains per-class information or an entry with key Model.ALL_OUTPUTS.

      The optional is empty if the model does not provide excuses.

      Parameters:
      example - The input example.
      Returns:
      An optional excuse object. The optional is empty if this model does not provide excuses.
    • getExcuses

      public Optional<List<Excuse<T>>> getExcuses(Iterable<Example<T>> examples)
      Generates an excuse for each example.

      This may be an expensive operation, and probably should be overridden in subclasses for performance reasons.

      These excuses either contain per-class information or an entry with key Model.ALL_OUTPUTS.

      The optional is empty if the model does not provide excuses.

      Parameters:
      examples - An iterable of examples
      Returns:
      An optional list of excuses. The Optional is empty if this model does not provide excuses.
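
      One plausible default for getExcuses calls getExcuse on each example and returns an empty Optional as soon as any example yields none. This sketch is an assumption about how such a default might look, using a hypothetical `excuse` function in place of getExcuse. It also shows why subclasses may want to override: it makes one independent, possibly expensive call per example.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

/** Hypothetical per-example fallback for batch excuse generation. */
final class ExcusesSketch {
    private ExcusesSketch() {}

    static <E, X> Optional<List<X>> excuses(Iterable<E> examples, Function<E, Optional<X>> excuse) {
        List<X> results = new ArrayList<>();
        for (E example : examples) {
            Optional<X> single = excuse.apply(example);
            if (!single.isPresent()) {
                return Optional.empty(); // the model cannot explain this example
            }
            results.add(single.get());
        }
        return Optional.of(results);
    }
}
```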
    • copy

      public Model<T> copy()
      Copies a model, returning a deep copy of any mutable state, and a shallow copy otherwise.
      Returns:
      A copy of the model.
    • copy

      protected abstract Model<T> copy(String newName, ModelProvenance newProvenance)
      Copies a model, replacing its provenance and name with the supplied values.

      Used to provide the provenance removal functionality.

      Parameters:
      newName - The new name.
      newProvenance - The new provenance.
      Returns:
      A copy of the model.
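
      The two copy methods fit a template-method pattern. Here the public no-argument copy() delegates to the protected abstract overload, passing the current name and provenance. That delegation is an assumption, not something this page confirms. `CopyableModel` and `WeightsModel` are hypothetical, with a String provenance and a mutable double[] showing the rule that mutable state is deep-copied.

```java
import java.util.Arrays;

/** Hypothetical base class showing copy() delegating to copy(newName, newProvenance). */
abstract class CopyableModel {
    protected String name;
    protected final String provenance;

    protected CopyableModel(String name, String provenance) {
        this.name = name;
        this.provenance = provenance;
    }

    /** Public copy keeps the current name and provenance. */
    public CopyableModel copy() {
        return copy(name, provenance);
    }

    /** Subclasses decide how to duplicate their own state. */
    protected abstract CopyableModel copy(String newName, String newProvenance);
}

/** Hypothetical subclass whose weight array is mutable, so it is deep-copied. */
final class WeightsModel extends CopyableModel {
    final double[] weights;

    WeightsModel(String name, String provenance, double[] weights) {
        super(name, provenance);
        this.weights = weights;
    }

    @Override
    protected CopyableModel copy(String newName, String newProvenance) {
        return new WeightsModel(newName, newProvenance, Arrays.copyOf(weights, weights.length));
    }
}
```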
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • castModel

      public <U extends Output<U>> Model<U> castModel(Class<U> outputType)
      Casts the model to the specified output type, assuming it is valid. If it's not valid, throws ClassCastException.

      This method is intended for use on a deserialized model to restore its generic type in a safe way.

      Type Parameters:
      U - The output type.
      Parameters:
      outputType - The output type to cast to.
      Returns:
      The model, cast to the requested output type.
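
      A sketch of the validate-then-cast pattern described for castModel. It uses a hypothetical generic `Box<T>` in place of Model<T>. Because the type argument is erased, the method inspects the stored values, throws ClassCastException on a mismatch, and otherwise performs an unchecked cast. This mirrors the documented contract; it is not Tribuo's source.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical generic container standing in for Model<T>. */
final class Box<T> {
    private final List<T> domain;

    Box(List<T> domain) {
        this.domain = new ArrayList<>(domain);
    }

    List<T> domain() {
        return Collections.unmodifiableList(domain);
    }

    /** Runtime check standing in for validate: every stored value must be an instance of clazz. */
    boolean validate(Class<?> clazz) {
        for (T value : domain) {
            if (!clazz.isInstance(value)) {
                return false;
            }
        }
        return true;
    }

    /** Restores a generic type after validation, throwing ClassCastException if it does not match. */
    @SuppressWarnings("unchecked")
    <U> Box<U> cast(Class<U> type) {
        if (!validate(type)) {
            throw new ClassCastException("Box does not hold values of type " + type.getName());
        }
        return (Box<U>) (Box<?>) this;
    }
}
```

      A caller holding a Box<?> (the analogue of a freshly deserialized Model<?>) can call cast(Integer.class) to obtain a Box<Integer>, or get an immediate ClassCastException if the contents do not match.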