Class TensorFlowSequenceModel<T extends Output<T>>
java.lang.Object
org.tribuo.sequence.SequenceModel<T>
org.tribuo.interop.tensorflow.sequence.TensorFlowSequenceModel<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, AutoCloseable, ProtoSerializable<org.tribuo.protos.core.SequenceModelProto>
public class TensorFlowSequenceModel<T extends Output<T>>
extends SequenceModel<T>
implements AutoCloseable
A TensorFlow model which implements SequenceModel, suitable for use in sequential prediction tasks.
N.B. TensorFlow support is experimental and may change without a major version bump.
-
Field Summary
Modifier and Type | Field | Description
static final int | CURRENT_VERSION | Protobuf serialization version.
protected final SequenceFeatureConverter | featureConverter |
protected final SequenceOutputConverter<T> | outputConverter |
protected final String | predictOp |
Fields inherited from class org.tribuo.sequence.SequenceModel
featureIDMap, name, outputIDMap, provenanceOutput
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
-
Method Summary
Modifier and Type | Method | Description
void | close() | Close the session and graph if they exist.
static TensorFlowSequenceModel<?> | deserializeFromProto(int version, String className, com.google.protobuf.Any message) | Deserialization factory.
Map<String, List<com.oracle.labs.mlrg.olcut.util.Pair<String, Double>>> | getTopFeatures(int i) | Returns an empty map, as the top features are not well defined for most TensorFlow models.
List<Prediction<T>> | predict(SequenceExample<T> example) | Uses the model to predict the output for a single example.
org.tribuo.protos.core.SequenceModelProto | serialize() | Serializes this object to a protobuf.
Methods inherited from class org.tribuo.sequence.SequenceModel
castModel, createDataCarrier, deserialize, deserializeFromFile, deserializeFromStream, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, serializeToFile, serializeToStream, setName, toMaxLabels, toString, validate
-
Field Details
-
CURRENT_VERSION
public static final int CURRENT_VERSION
Protobuf serialization version.
-
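featureConverter
protected final SequenceFeatureConverter featureConverter
-
outputConverter
protected final SequenceOutputConverter<T> outputConverter
-
predictOp
protected final String predictOp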
-
-
Method Details
-
deserializeFromProto
public static TensorFlowSequenceModel<?> deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
Deserialization factory.
- Parameters:
version - The serialized object version.
className - The class name.
message - The serialized data.
- Returns:
The deserialized object.
- Throws:
com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
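The version parameter suggests that the factory checks the serialized version against CURRENT_VERSION before decoding. The following is a minimal sketch of such a check, not the library's implementation: the class VersionGateSketch, the method checkVersion, and the constant's value are all hypothetical stand-ins, and no protobuf parsing is performed.

```java
public class VersionGateSketch {
    /** Hypothetical stand-in for the class's CURRENT_VERSION field; the value is assumed. */
    public static final int CURRENT_VERSION = 0;

    /**
     * Rejects serialized versions that this (hypothetical) class cannot read.
     * A real factory would go on to unpack the protobuf message after this check.
     */
    public static void checkVersion(int version) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException(
                "Unknown version " + version + ", supported up to " + CURRENT_VERSION);
        }
    }

    public static void main(String[] args) {
        checkVersion(0); // accepted
        try {
            checkVersion(1); // newer than this sketch understands
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Failing fast on an unknown version keeps a newer serialized form from being misread by older code.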
-
predict
public List<Prediction<T>> predict(SequenceExample<T> example)
Description copied from class: SequenceModel
Uses the model to predict the output for a single example.
- Specified by:
predict in class SequenceModel<T extends Output<T>>
- Parameters:
example - the example to predict.
- Returns:
the result of the prediction.
-
getTopFeatures
public Map<String, List<com.oracle.labs.mlrg.olcut.util.Pair<String, Double>>> getTopFeatures(int i)
Returns an empty map, as the top features are not well defined for most TensorFlow models.
- Specified by:
getTopFeatures in class SequenceModel<T extends Output<T>>
- Parameters:
i - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
- Returns:
a map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the model.
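Because this method always returns an empty map, code that reports feature importances should treat the empty case as "not available" rather than as an error. The sketch below illustrates that handling with standard-library types only: TopFeaturesSketch and its methods are hypothetical stand-ins, and Map.Entry is used in place of the library's pair type.

```java
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class TopFeaturesSketch {
    /** Hypothetical stand-in for a model whose top-features query always yields an empty map. */
    public static Map<String, List<Map.Entry<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /** Builds a short report, falling back to a notice when no feature scores exist. */
    public static String describe(Map<String, List<Map.Entry<String, Double>>> top) {
        if (top.isEmpty()) {
            return "no feature importances available";
        }
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, List<Map.Entry<String, Double>>> e : top.entrySet()) {
            sb.append(e.getKey()).append(": ")
              .append(e.getValue().size()).append(" features\n");
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(describe(getTopFeatures(5)));
    }
}
```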
-
close
public void close()
Close the session and graph if they exist.
- Specified by:
close in interface AutoCloseable
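Since the class implements AutoCloseable, a try-with-resources block is a natural way to make sure the native resources are released. The sketch below shows the pattern with a hypothetical stand-in class (NativeBackedModel) rather than the real model. Its close method only releases the resource if it is still held, which mirrors the "if they exist" wording above; that idempotent behavior is an assumption of the sketch.

```java
public class CloseSketch {
    /** Hypothetical stand-in for a model that owns a native session and graph. */
    public static final class NativeBackedModel implements AutoCloseable {
        private Object session = new Object(); // placeholder for a native handle
        private int releaseCount = 0;

        public String predict(String input) {
            if (session == null) {
                throw new IllegalStateException("model is closed");
            }
            return input.toUpperCase();
        }

        /** Releases the handle only if it still exists, so repeated calls are harmless. */
        @Override
        public void close() {
            if (session != null) {
                session = null;
                releaseCount++;
            }
        }

        public boolean isClosed() {
            return session == null;
        }

        public int releaseCount() {
            return releaseCount;
        }
    }

    public static void main(String[] args) {
        NativeBackedModel model = new NativeBackedModel();
        try (NativeBackedModel m = model) {
            System.out.println(m.predict("token"));
        } // close() runs here, even if predict throws
        model.close(); // second call is a no-op
        System.out.println("closed=" + model.isClosed() + ", releases=" + model.releaseCount());
    }
}
```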
-
serialize
public org.tribuo.protos.core.SequenceModelProto serialize()
Description copied from interface: ProtoSerializable
Serializes this object to a protobuf.
- Specified by:
serialize in interface ProtoSerializable<org.tribuo.protos.core.SequenceModelProto>
- Overrides:
serialize in class SequenceModel<T extends Output<T>>
- Returns:
The protobuf.
-