Package org.tribuo.sequence
Class SequenceModel<T extends Output<T>>
java.lang.Object
org.tribuo.sequence.SequenceModel<T>
- Type Parameters:
T - the type of the outputs used to train the model.
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ProtoSerializable<org.tribuo.protos.core.SequenceModelProto>
- Direct Known Subclasses:
ConfidencePredictingSequenceModel, IndependentSequenceModel, TensorFlowSequenceModel, ViterbiModel
public abstract class SequenceModel<T extends Output<T>>
extends Object
implements ProtoSerializable<org.tribuo.protos.core.SequenceModelProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable
A prediction model for sequence data, which is used to predict outputs for unseen sequence examples.
-
Field Summary
protected final ImmutableFeatureMap featureIDMap: The feature domain.
protected String name: The model name.
protected final ImmutableOutputInfo<T> outputIDMap: The output domain.
protected final String provenanceOutput: The toString of the model provenance.
Fields inherited from interface org.tribuo.protos.ProtoSerializable: DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
-
Constructor Summary
SequenceModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap): Builds a SequenceModel.
-
Method Summary
<U extends Output<U>> SequenceModel<U> castModel(Class<U> outputType): Casts the model to the specified output type, assuming it is valid.
protected ModelDataCarrier<T> createDataCarrier(): Constructs the data carrier for serialization.
static SequenceModel<?> deserialize(org.tribuo.protos.core.SequenceModelProto proto): Deserializes the model from the supplied protobuf.
static SequenceModel<?> deserializeFromFile(Path path): Reads an instance of SequenceModelProto from the supplied path and deserializes it.
static SequenceModel<?> deserializeFromStream(InputStream is): Reads an instance of SequenceModelProto from the supplied input stream and deserializes it.
ImmutableFeatureMap getFeatureIDMap(): Gets the feature domain.
String getName(): Gets the model name.
ImmutableOutputInfo<T> getOutputIDInfo(): Gets the output domain.
ModelProvenance getProvenance()
abstract Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n): Gets the top n features associated with this model.
List<List<Prediction<T>>> predict(Iterable<SequenceExample<T>> examples): Uses the model to predict the output for multiple examples.
List<List<Prediction<T>>> predict(SequenceDataset<T> examples): Uses the model to predict the labels for multiple examples contained in a data set.
abstract List<Prediction<T>> predict(SequenceExample<T> example): Uses the model to predict the output for a single example.
org.tribuo.protos.core.SequenceModelProto serialize(): Serializes this object to a protobuf.
void serializeToFile(Path path): Serializes this model to a SequenceModelProto and writes it to the supplied path.
void serializeToStream(OutputStream stream): Serializes this model to a SequenceModelProto and writes it to the supplied output stream.
void setName(String name): Sets the model name.
static <T extends Output<T>> List<T> toMaxLabels(List<Prediction<T>> predictions): Extracts a list of the predicted outputs from the list of prediction objects.
String toString()
boolean validate(Class<? extends Output<?>> clazz): Validates that this Model does in fact support the supplied output type.
-
Field Details
-
name
The model name.
-
provenanceOutput
The toString of the model provenance.
-
featureIDMap
The feature domain.
-
outputIDMap
The output domain.
-
Constructor Details
-
SequenceModel
public SequenceModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap)
Builds a SequenceModel.
- Parameters:
name - The model name.
provenance - The model provenance.
featureIDMap - The feature domain.
outputIDMap - The output domain.
-
Method Details
-
validate
Validates that this Model does in fact support the supplied output type.
As the output type is erased at runtime, deserializing a Model is an unchecked operation. This method allows the user to check that the deserialized model is of the appropriate type, rather than discovering a ClassCastException only when predict(org.tribuo.sequence.SequenceExample<T>) is called.
- Parameters:
clazz - The class object to verify the output type against.
- Returns:
- True if the output type is assignable to the class object type, false otherwise.
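Because the check cannot rely on the erased type parameter, one plausible approach is to inspect the runtime classes of the values in the model's output domain. The following is a minimal, JDK-only sketch of that idea; the class OutputTypeCheck and its validate(Collection<?>, Class<?>) method are hypothetical stand-ins, not Tribuo API, and Tribuo's actual implementation may differ.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

// Hypothetical sketch, not Tribuo API: checks an erased output domain at runtime.
final class OutputTypeCheck {
    private OutputTypeCheck() {}

    // Returns true only if every output in the domain is an instance of clazz.
    // Generics are erased, so the runtime class of each value is the only evidence available.
    static boolean validate(Collection<?> erasedDomain, Class<?> clazz) {
        for (Object output : erasedDomain) {
            if (!clazz.isInstance(output)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // Pretend this domain came from a deserialized model whose type parameter was erased.
        List<?> domain = Arrays.asList("B-PER", "I-PER", "O");
        System.out.println(validate(domain, String.class));  // prints true
        System.out.println(validate(domain, Integer.class)); // prints false
    }
}
```

Checking before use turns a late ClassCastException inside predict into an early, explicit boolean decision.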
-
getName
Gets the model name.
- Returns:
- The model name.
-
setName
Sets the model name.
- Parameters:
name - The model name.
-
getProvenance
-
toString
-
getFeatureIDMap
Gets the feature domain.
- Returns:
- The feature domain.
-
getOutputIDInfo
Gets the output domain.
- Returns:
- The output domain.
-
predict
Uses the model to predict the output for a single example.
- Parameters:
example - the example to predict.
- Returns:
- the result of the prediction.
-
predict
Uses the model to predict the output for multiple examples.
- Parameters:
examples - the examples to predict.
- Returns:
- the results of the prediction, in the same order as the examples.
-
predict
Uses the model to predict the labels for multiple examples contained in a data set.
- Parameters:
examples - the data set containing the examples to predict.
- Returns:
- the results of the predictions, in the same order as the data set generates the examples.
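The multi-example predict methods return one result list per input sequence, in input order. A minimal sketch of how such a batch prediction can be expressed in terms of the single-example predict, assuming a simple loop over the input: BatchPredict.predictAll is hypothetical, the type variable S stands in for SequenceExample<T>, P stands in for Prediction<T>, and the toy model that labels each token with its length is purely illustrative.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch, not Tribuo API: batch prediction built on a single-sequence predictor.
final class BatchPredict {
    private BatchPredict() {}

    static <S, P> List<List<P>> predictAll(Iterable<S> examples, Function<S, List<P>> predictOne) {
        List<List<P>> results = new ArrayList<>();
        for (S example : examples) {
            // One inner list per input sequence, appended in iteration order.
            results.add(predictOne.apply(example));
        }
        return results;
    }

    public static void main(String[] args) {
        // Toy "model": predicts the length of each token.
        Function<List<String>, List<Integer>> toyModel = seq -> {
            List<Integer> out = new ArrayList<>();
            for (String token : seq) {
                out.add(token.length());
            }
            return out;
        };
        List<List<String>> sentences = Arrays.asList(
                Arrays.asList("the", "cat"),
                Arrays.asList("sat"));
        System.out.println(predictAll(sentences, toyModel)); // prints [[3, 3], [3]]
    }
}
```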
-
serialize
public org.tribuo.protos.core.SequenceModelProto serialize()
Description copied from interface: ProtoSerializable
Serializes this object to a protobuf.
- Specified by:
serialize in interface ProtoSerializable<org.tribuo.protos.core.SequenceModelProto>
- Returns:
- The protobuf.
-
serializeToFile
Serializes this model to a SequenceModelProto and writes it to the supplied path.
- Parameters:
path - The path to write to.
- Throws:
IOException - If the path could not be written to.
-
serializeToStream
Serializes this model to a SequenceModelProto and writes it to the supplied output stream. Does not close the stream.
- Parameters:
stream - The output stream to write to.
- Throws:
IOException - If the stream could not be written to.
-
deserialize
Deserializes the model from the supplied protobuf.
- Parameters:
proto - The protobuf to deserialize.
- Returns:
- The model.
-
deserializeFromFile
Reads an instance of SequenceModelProto from the supplied path and deserializes it.
- Parameters:
path - The path to read.
- Returns:
- The deserialized model.
- Throws:
IOException - If the path could not be read from, or the parsing failed.
-
deserializeFromStream
Reads an instance of SequenceModelProto from the supplied input stream and deserializes it. Does not close the stream.
- Parameters:
is - The input stream to read.
- Returns:
- The deserialized model.
- Throws:
IOException - If the stream could not be read from, or the parsing failed.
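Since serializeToStream and deserializeFromStream do not close the stream, the caller owns the stream's lifetime, typically through try-with-resources. The sketch below illustrates that ownership pattern using only the JDK; writePayload and readPayload are hypothetical stand-ins for the Tribuo methods and write raw UTF-8 text rather than a SequenceModelProto.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical sketch, not Tribuo API: illustrates caller-owned streams.
final class StreamOwnership {
    private StreamOwnership() {}

    // Stand-in for serializeToStream: writes and flushes, but never closes the stream.
    static void writePayload(String payload, OutputStream stream) throws IOException {
        stream.write(payload.getBytes(StandardCharsets.UTF_8));
        stream.flush();
    }

    // Stand-in for deserializeFromStream: reads to end of stream, but never closes it (Java 9+).
    static String readPayload(InputStream is) throws IOException {
        return new String(is.readAllBytes(), StandardCharsets.UTF_8);
    }

    // The caller opens and closes both streams with try-with-resources.
    static String roundTrip(String payload, Path path) throws IOException {
        try (OutputStream out = Files.newOutputStream(path)) {
            writePayload(payload, out);
        }
        try (InputStream in = Files.newInputStream(path)) {
            return readPayload(in);
        }
    }
}
```

Leaving the stream open lets a caller write other data to the same stream, or wrap it in its own buffering or compression layer, before closing it.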
-
createDataCarrier
Constructs the data carrier for serialization.
- Returns:
- The serialization data carrier.
-
getTopFeatures
public abstract Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
Gets the top n features associated with this model.
If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.
If the model cannot describe its top features, it returns Collections.emptyMap().
- Parameters:
n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
- Returns:
- a map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the model.
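The following sketch shows one way to satisfy the documented getTopFeatures contract: an ordered list of (feature, weight) pairs per output, with n < 0 meaning all features. It is hypothetical: TopFeatures.topFeatures takes an assumed per-output weight map, uses Map.Entry in place of the OLCUT Pair, and ranks by descending absolute weight, which is an assumption rather than documented Tribuo behavior.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch, not Tribuo API: builds a top-n feature map per output.
final class TopFeatures {
    private TopFeatures() {}

    // weightsPerOutput maps output name -> (feature name -> weight); an assumed input shape.
    static Map<String, List<Map.Entry<String, Double>>> topFeatures(
            Map<String, Map<String, Double>> weightsPerOutput, int n) {
        Map<String, List<Map.Entry<String, Double>>> result = new LinkedHashMap<>();
        for (Map.Entry<String, Map<String, Double>> output : weightsPerOutput.entrySet()) {
            List<Map.Entry<String, Double>> ranked = new ArrayList<>();
            for (Map.Entry<String, Double> w : output.getValue().entrySet()) {
                ranked.add(new AbstractMap.SimpleImmutableEntry<>(w.getKey(), w.getValue()));
            }
            // Assumed ordering: largest absolute weight first.
            ranked.sort(Comparator.comparingDouble(
                    (Map.Entry<String, Double> e) -> Math.abs(e.getValue())).reversed());
            // n < 0 returns every feature, as the contract above describes.
            int limit = n < 0 ? ranked.size() : Math.min(n, ranked.size());
            result.put(output.getKey(), new ArrayList<>(ranked.subList(0, limit)));
        }
        return result;
    }
}
```

A model without per-output weights would place a single list under the Model.ALL_OUTPUTS key instead of one entry per output, and a model that cannot score features would return an empty map.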
-
toMaxLabels
Extracts a list of the predicted outputs from the list of prediction objects.
- Type Parameters:
T - The prediction type.
- Parameters:
predictions - The predictions.
- Returns:
- A list of predicted outputs.
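A minimal sketch of the documented toMaxLabels behavior, assuming each prediction exposes its predicted output through a getOutput() accessor. SimplePrediction and MaxLabels are hypothetical stand-ins for Prediction<T> and SequenceModel, used so the example runs without Tribuo on the classpath.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for Prediction<T>: only the predicted output is modelled.
final class SimplePrediction<T> {
    private final T output;

    SimplePrediction(T output) {
        this.output = output;
    }

    T getOutput() {
        return output;
    }
}

// Hypothetical sketch, not Tribuo API.
final class MaxLabels {
    private MaxLabels() {}

    // One predicted output per prediction, preserving the sequence order.
    static <T> List<T> toMaxLabels(List<SimplePrediction<T>> predictions) {
        List<T> labels = new ArrayList<>(predictions.size());
        for (SimplePrediction<T> p : predictions) {
            labels.add(p.getOutput());
        }
        return labels;
    }
}
```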
-
castModel
Casts the model to the specified output type, assuming it is valid. If it is not valid, throws ClassCastException.
This method is intended for use on a deserialized model to restore its generic type in a safe way.
- Type Parameters:
U - The output type.
- Parameters:
outputType - The output type to cast to.
- Returns:
- The model cast to the correct type.
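castModel pairs the validate check with an unchecked cast. The sketch below shows that validate-then-cast pattern on a hypothetical TypedHolder<T> class standing in for a deserialized model; its cast method throws ClassCastException when validation fails, as the castModel description states. None of these names are Tribuo API.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical generic holder whose type parameter is erased at runtime, like a deserialized model.
final class TypedHolder<T> {
    private final List<T> domain;

    TypedHolder(List<T> domain) {
        this.domain = Collections.unmodifiableList(new ArrayList<>(domain));
    }

    List<T> getDomain() {
        return domain;
    }

    // Runtime check: every value in the domain must be an instance of clazz.
    boolean validate(Class<?> clazz) {
        for (T value : domain) {
            if (!clazz.isInstance(value)) {
                return false;
            }
        }
        return true;
    }

    // Validate first, then perform the unchecked cast; fail loudly otherwise.
    @SuppressWarnings("unchecked")
    <U> TypedHolder<U> cast(Class<U> type) {
        if (validate(type)) {
            return (TypedHolder<U>) this;
        }
        throw new ClassCastException("Holder outputs are not all of type " + type.getName());
    }
}
```

With Tribuo itself, the analogous flow is to load a model with deserializeFromFile(path), which returns SequenceModel<?>, and then call castModel with the expected output class.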