Class SequenceModel<T extends Output<T>>

java.lang.Object
org.tribuo.sequence.SequenceModel<T>
Type Parameters:
T - the type of the outputs used to train the model.
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ProtoSerializable<org.tribuo.protos.core.SequenceModelProto>
Direct Known Subclasses:
ConfidencePredictingSequenceModel, IndependentSequenceModel, TensorFlowSequenceModel, ViterbiModel

public abstract class SequenceModel<T extends Output<T>> extends Object implements ProtoSerializable<org.tribuo.protos.core.SequenceModelProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable
A prediction model, which is used to predict outputs for unseen instances.
  • Field Details

    • name

      protected String name
      The model name.
    • provenanceOutput

      protected final String provenanceOutput
      The toString of the model provenance.
    • featureIDMap

      protected final ImmutableFeatureMap featureIDMap
      The feature domain.
    • outputIDMap

      protected final ImmutableOutputInfo<T> outputIDMap
      The output domain.
  • Constructor Details

    • SequenceModel

      public SequenceModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap)
      Builds a SequenceModel.
      Parameters:
      name - The model name.
      provenance - The model provenance.
      featureIDMap - The feature domain.
      outputIDMap - The output domain.
  • Method Details

    • validate

      public boolean validate(Class<? extends Output<?>> clazz)
      Validates that this model does in fact support the supplied output type.

      As the output type is erased at runtime, deserializing a model is an unchecked operation. This method allows the user to check that the deserialized model is of the appropriate type, rather than waiting to see whether predict(org.tribuo.sequence.SequenceExample<T>) throws a ClassCastException when called.

      Parameters:
      clazz - The class object to verify the output type against.
      Returns:
      True if the output type is assignable to the class object type, false otherwise.
    • getName

      public String getName()
      Gets the model name.
      Returns:
      The model name.
    • setName

      public void setName(String name)
      Sets the model name.
      Parameters:
      name - The model name.
    • getProvenance

      public ModelProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getFeatureIDMap

      public ImmutableFeatureMap getFeatureIDMap()
      Gets the feature domain.
      Returns:
      The feature domain.
    • getOutputIDInfo

      public ImmutableOutputInfo<T> getOutputIDInfo()
      Gets the output domain.
      Returns:
      The output domain.
    • predict

      public abstract List<Prediction<T>> predict(SequenceExample<T> example)
      Uses the model to predict the outputs for a single sequence example.
      Parameters:
      example - the example to predict.
      Returns:
      the predictions, one per element of the sequence.
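      A sequence prediction yields one prediction per element of the input sequence. The sketch below shows the simplest way a subclass could meet that contract: apply a per-element predictor to each element independently, loosely in the spirit of IndependentSequenceModel. The class IndependentSketch is hypothetical, a java.util.function.Function stands in for an element-level model, and plain generic types stand in for Tribuo's SequenceExample and Prediction; this is not Tribuo's implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch (not a Tribuo class): predicts each element of a sequence on its own.
final class IndependentSketch<E, P> {
    private final Function<E, P> elementPredictor;

    IndependentSketch(Function<E, P> elementPredictor) {
        this.elementPredictor = elementPredictor;
    }

    // Returns exactly one prediction per element, in sequence order.
    List<P> predict(List<E> sequence) {
        List<P> predictions = new ArrayList<>(sequence.size());
        for (E element : sequence) {
            predictions.add(elementPredictor.apply(element));
        }
        return predictions;
    }
}
```

      A model such as a Viterbi decoder chooses the labels for a sequence jointly rather than independently, but the shape of the result, one prediction per element, is the same.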
    • predict

      public List<List<Prediction<T>>> predict(Iterable<SequenceExample<T>> examples)
      Uses the model to predict the output for multiple examples.
      Parameters:
      examples - the examples to predict.
      Returns:
      the results of the prediction, in the same order as the examples.
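      The batch overload can be read as the single-example predict applied to each example in iteration order, collecting one inner list per sequence. A minimal sketch, assuming a hypothetical abstract class ToySequenceModel whose generic types stand in for Tribuo's SequenceExample and Prediction; it illustrates the ordering guarantee only and is not Tribuo's source.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical base class (not a Tribuo class) mirroring the two predict overloads.
abstract class ToySequenceModel<E, P> {
    // Single-sequence prediction: one P per element of the sequence.
    abstract List<P> predict(List<E> sequence);

    // Batch prediction: results come back in the order the iterable yields the sequences.
    List<List<P>> predict(Iterable<List<E>> sequences) {
        List<List<P>> results = new ArrayList<>();
        for (List<E> sequence : sequences) {
            results.add(predict(sequence));
        }
        return results;
    }
}
```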
    • predict

      public List<List<Prediction<T>>> predict(SequenceDataset<T> examples)
      Uses the model to predict the outputs for multiple examples contained in a dataset.
      Parameters:
      examples - the dataset containing the examples to predict.
      Returns:
      the results of the predictions, in the same order as the dataset generates the examples.
    • serialize

      public org.tribuo.protos.core.SequenceModelProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.protos.core.SequenceModelProto>
      Returns:
      The protobuf.
    • serializeToFile

      public void serializeToFile(Path path) throws IOException
      Serializes this model to a SequenceModelProto and writes it to the supplied path.
      Parameters:
      path - The path to write to.
      Throws:
      IOException - If the path could not be written to.
    • serializeToStream

      public void serializeToStream(OutputStream stream) throws IOException
      Serializes this model to a SequenceModelProto and writes it to the supplied output stream.

      Does not close the stream.

      Parameters:
      stream - The output stream to write to.
      Throws:
      IOException - If the stream could not be written to.
    • deserialize

      public static SequenceModel<?> deserialize(org.tribuo.protos.core.SequenceModelProto proto)
      Deserializes the model from the supplied protobuf.
      Parameters:
      proto - The protobuf to deserialize.
      Returns:
      The model.
    • deserializeFromFile

      public static SequenceModel<?> deserializeFromFile(Path path) throws IOException
      Reads an instance of SequenceModelProto from the supplied path and deserializes it.
      Parameters:
      path - The path to read.
      Returns:
      The deserialized model.
      Throws:
      IOException - If the path could not be read from, or the parsing failed.
    • deserializeFromStream

      public static SequenceModel<?> deserializeFromStream(InputStream is) throws IOException
      Reads an instance of SequenceModelProto from the supplied input stream and deserializes it.

      Does not close the stream.

      Parameters:
      is - The input stream to read.
      Returns:
      The deserialized model.
      Throws:
      IOException - If the stream could not be read from, or the parsing failed.
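      Both serializeToStream and deserializeFromStream leave the stream open, so the caller owns its lifecycle. The sketch below makes that contract visible with hypothetical helpers, writePayload and readPayload, that read or write raw bytes without closing the stream, plus a hypothetical CloseTrackingOutputStream that records whether close() was called. The byte payload is a stand-in for a serialized SequenceModelProto; none of these names are Tribuo API.

```java
import java.io.ByteArrayOutputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

// Hypothetical helpers following the same rule as serializeToStream and
// deserializeFromStream: they read or write, but never close the stream.
final class StreamContract {
    static void writePayload(byte[] payload, OutputStream stream) throws IOException {
        stream.write(payload);
        stream.flush(); // flush, but leave closing to the caller
    }

    static byte[] readPayload(InputStream is) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        byte[] chunk = new byte[4096];
        int n;
        while ((n = is.read(chunk)) != -1) {
            buffer.write(chunk, 0, n);
        }
        return buffer.toByteArray(); // 'is' is left open
    }
}

// Records whether close() was called, to show who is responsible for closing.
final class CloseTrackingOutputStream extends FilterOutputStream {
    boolean closeCalled = false;

    CloseTrackingOutputStream(OutputStream out) {
        super(out);
    }

    @Override
    public void close() throws IOException {
        closeCalled = true;
        super.close();
    }
}
```

      With the real class, the same rule means the caller closes the stream, typically with try-with-resources: `try (OutputStream os = Files.newOutputStream(path)) { model.serializeToStream(os); }`.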
    • createDataCarrier

      protected ModelDataCarrier<T> createDataCarrier()
      Constructs the data carrier for serialization.
      Returns:
      The serialization data carrier.
    • getTopFeatures

      public abstract Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
      Gets the top n features associated with this model.

      If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.

      If the model cannot describe its top features then it returns Collections.emptyMap().

      Parameters:
      n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
      Returns:
      a map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the model
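      The shape of the returned map, and the meaning of a negative n, can be sketched with a hypothetical helper that ranks per-output feature weights. Assumptions: java.util.Map.Entry stands in for OLCUT's Pair<String,Double>, and ranking by absolute weight is only one plausible scoring rule; each model decides for itself how to score its features.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper (not Tribuo API): builds a getTopFeatures-style map from
// a map of output name -> (feature name -> weight).
final class TopFeaturesSketch {
    static Map<String, List<Map.Entry<String, Double>>> topFeatures(
            Map<String, Map<String, Double>> weightsPerOutput, int n) {
        Map<String, List<Map.Entry<String, Double>>> result = new LinkedHashMap<>();
        for (Map.Entry<String, Map<String, Double>> output : weightsPerOutput.entrySet()) {
            List<Map.Entry<String, Double>> features = new ArrayList<>();
            for (Map.Entry<String, Double> f : output.getValue().entrySet()) {
                features.add(new AbstractMap.SimpleImmutableEntry<>(f.getKey(), f.getValue()));
            }
            // Assumption: rank by absolute weight, largest first.
            features.sort(Comparator.comparingDouble(
                    (Map.Entry<String, Double> e) -> Math.abs(e.getValue())).reversed());
            // A negative n means "return every feature" for this output.
            int limit = n < 0 ? features.size() : Math.min(n, features.size());
            result.put(output.getKey(), new ArrayList<>(features.subList(0, limit)));
        }
        return result;
    }
}
```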
    • toMaxLabels

      public static <T extends Output<T>> List<T> toMaxLabels(List<Prediction<T>> predictions)
      Extracts a list of the predicted outputs from the list of prediction objects.
      Type Parameters:
      T - The prediction type.
      Parameters:
      predictions - The predictions.
      Returns:
      A list of predicted outputs.
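      For a sequence, this turns the list of predictions, one per element, into the list of predicted outputs. A sketch under the assumption that a hypothetical ToyPrediction type stands in for Tribuo's Prediction, mirroring its getOutput() accessor.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for Tribuo's Prediction: keeps only the predicted output.
final class ToyPrediction<T> {
    private final T output;

    ToyPrediction(T output) {
        this.output = output;
    }

    T getOutput() {
        return output;
    }
}

final class MaxLabels {
    // Extracts the predicted output from each prediction, preserving sequence order.
    static <T> List<T> toMaxLabels(List<ToyPrediction<T>> predictions) {
        List<T> outputs = new ArrayList<>(predictions.size());
        for (ToyPrediction<T> p : predictions) {
            outputs.add(p.getOutput());
        }
        return outputs;
    }
}
```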
    • castModel

      public <U extends Output<U>> SequenceModel<U> castModel(Class<U> outputType)
      Casts the model to the specified output type, assuming it is valid. If it is not valid, throws a ClassCastException.

      This method is intended for use on a deserialized model to restore its generic type in a safe way.

      Type Parameters:
      U - The output type.
      Parameters:
      outputType - The output type to cast to.
      Returns:
      The model cast to the requested output type.
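      The validate-then-cast pattern behind validate and castModel can be sketched with a hypothetical generic ToyTypedModel: validate checks every output in the domain with Class.isInstance, and castModel performs the unchecked cast only when validation succeeds, otherwise throwing ClassCastException. This is an illustration under those assumptions, not Tribuo's source.

```java
import java.util.List;

// Hypothetical generic model (not a Tribuo class) holding outputs of type T.
final class ToyTypedModel<T> {
    private final List<T> domain;

    ToyTypedModel(List<T> domain) {
        this.domain = domain;
    }

    // True if every output in the domain is an instance of clazz.
    boolean validate(Class<?> clazz) {
        for (T o : domain) {
            if (!clazz.isInstance(o)) {
                return false;
            }
        }
        return true;
    }

    // Restores the generic type after validation; fails fast otherwise.
    @SuppressWarnings("unchecked")
    <U> ToyTypedModel<U> castModel(Class<U> outputType) {
        if (validate(outputType)) {
            return (ToyTypedModel<U>) (ToyTypedModel<?>) this;
        }
        throw new ClassCastException("Model outputs are not of type " + outputType.getName());
    }
}
```

      With the real class, the same workflow would read `SequenceModel<Label> typed = SequenceModel.deserializeFromFile(path).castModel(Label.class);`, where Label stands for whichever output type the model was trained on.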