Class TensorFlowCheckpointModel<T extends Output<T>>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Closeable, Serializable, AutoCloseable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
If the checkpoint is not available on construction or after deserialisation then the model is uninitialised. An uninitialised model can be initialised by calling setCheckpointDirectory(java.lang.String) and setCheckpointName(java.lang.String) with the correct directory and name respectively, and then calling initialize().
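The call order described above can be pictured with a small stand-in class. This is a minimal sketch, not Tribuo's implementation: `CheckpointLifecycleSketch` is a hypothetical class, and it assumes that "available" simply means the checkpoint directory exists on disk.

```java
import java.nio.file.Files;
import java.nio.file.Paths;

/** Hypothetical stand-in (not the Tribuo class) for the set-then-initialize lifecycle. */
final class CheckpointLifecycleSketch {
    private String checkpointDirectory;
    private String checkpointName;
    private boolean initialized;

    CheckpointLifecycleSketch(String checkpointDirectory, String checkpointName) {
        this.checkpointDirectory = checkpointDirectory;
        this.checkpointName = checkpointName;
        // Assumption: the model starts uninitialised when the directory is missing.
        this.initialized = checkpointAvailable();
    }

    private boolean checkpointAvailable() {
        return Files.isDirectory(Paths.get(checkpointDirectory));
    }

    void setCheckpointDirectory(String newCheckpointDirectory) {
        this.checkpointDirectory = newCheckpointDirectory;
    }

    void setCheckpointName(String newCheckpointName) {
        this.checkpointName = newCheckpointName;
    }

    String getCheckpointName() {
        return checkpointName;
    }

    boolean isInitialized() {
        return initialized;
    }

    /** Marks the model ready only if the checkpoint directory can be found. */
    void initialize() {
        if (!checkpointAvailable()) {
            throw new IllegalStateException("No checkpoint directory at " + checkpointDirectory);
        }
        initialized = true;
    }

    public static void main(String[] args) {
        // Constructed against a missing directory, so the model starts uninitialised.
        CheckpointLifecycleSketch model = new CheckpointLifecycleSketch("/no/such/checkpoint-dir", "model");
        System.out.println("initialised on construction: " + model.isInitialized());
        // Point it at a directory that exists, then initialise explicitly.
        model.setCheckpointDirectory(System.getProperty("java.io.tmpdir"));
        model.setCheckpointName("model");
        model.initialize();
        System.out.println("initialised after initialize(): " + model.isInitialized());
    }
}
```

With the real class the same order applies: set the directory and name first, then call initialize(), and check isInitialized() before predicting.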
The serialized form of this model leaves the weights in the specified model checkpoint directory rather than storing them inside the model. To convert it into a model which stores the weights inside the model, call convertToNativeModel().
It accepts a FeatureConverter that converts an example's features into a TensorMap, and an OutputConverter that converts a Tensor into a Prediction.
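The data flow through the two converters can be sketched with plain Java types. This is an illustration under stated assumptions only: `ConverterPipelineSketch` and its methods are hypothetical, a `Map<String, float[]>` stands in for a TensorMap, a `float[]` stands in for a Tensor, an `int` class index stands in for a Prediction, and the input name "input" is invented for the example.

```java
import java.util.Collections;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical stand-in for the feature -> tensors -> graph -> prediction flow. */
final class ConverterPipelineSketch {

    /** Feature-side stand-in: sparse (index, value) features become one named dense input. */
    static Map<String, float[]> convertFeatures(Map<Integer, Float> features, int numFeatures) {
        float[] dense = new float[numFeatures];
        for (Map.Entry<Integer, Float> e : features.entrySet()) {
            dense[e.getKey()] = e.getValue();
        }
        return Collections.singletonMap("input", dense);
    }

    /** Output-side stand-in: a score vector becomes the index of its highest score. */
    static int convertOutput(float[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Runs the three stages in order; the graph is any function from named inputs to scores. */
    static int predict(Map<Integer, Float> features, int numFeatures,
                       Function<Map<String, float[]>, float[]> graph) {
        Map<String, float[]> inputs = convertFeatures(features, numFeatures);
        float[] output = graph.apply(inputs);
        return convertOutput(output);
    }
}
```

Keeping the two conversions separate from the graph call is what lets the same model class serve different input encodings and output types.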
The model's serialVersionUID is set to the major TensorFlow version number times 100.
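As a worked instance of that rule, the arithmetic is a single multiplication. The helper below is hypothetical and only illustrates the stated formula; the major version passed in is an example value, not a statement about which TensorFlow version a given Tribuo release uses.

```java
/** Hypothetical helper illustrating "major TensorFlow version number times 100". */
final class SerialVersionSketch {
    static long serialVersionUidFor(int tensorFlowMajorVersion) {
        // Multiply as a long so the result has the type of a serialVersionUID field.
        return tensorFlowMajorVersion * 100L;
    }
}
```

For example, a major version of 2 would give 200 under this rule.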
N.B. TensorFlow support is experimental and may change without a major version bump.
-
Field Summary
| Modifier and Type | Field | Description |
| --- | --- | --- |
| static final int | CURRENT_VERSION | Protobuf serialization version. |

Fields inherited from class org.tribuo.interop.tensorflow.TensorFlowModel
batchSize, closed, featureConverter, modelGraph, outputConverter, outputName, session
Fields inherited from class org.tribuo.Model
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
-
Method Summary
| Modifier and Type | Method | Description |
| --- | --- | --- |
| TensorFlowNativeModel<T> | convertToNativeModel() | Creates a TensorFlowNativeModel version of this model. |
| protected TensorFlowCheckpointModel<T> | copy(String newName, ModelProvenance newProvenance) | Copies a model, replacing its provenance and name with the supplied values. |
| static TensorFlowCheckpointModel<?> | deserializeFromProto(int version, String className, com.google.protobuf.Any message) | Deserialization factory. |
| String | getCheckpointDirectory() | Gets the checkpoint directory this model loads from. |
| String | getCheckpointName() | Gets the checkpoint name this model loads from. |
| final void | initialize() | Initializes the model. |
| boolean | isInitialized() | Is this model initialized? |
| org.tribuo.protos.core.ModelProto | serialize() | Serializes this object to a protobuf. |
| void | setCheckpointDirectory(String newCheckpointDirectory) | Sets the checkpoint directory. |
| void | setCheckpointName(String newCheckpointName) | Sets the checkpoint name. |

Methods inherited from class org.tribuo.interop.tensorflow.TensorFlowModel
close, exportModel, getBatchSize, getExcuse, getOutputName, getTopFeatures, innerPredict, predict, setBatchSize
Methods inherited from class org.tribuo.Model
castModel, copy, createDataCarrier, deserialize, deserializeFromFile, deserializeFromStream, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, serializeToFile, serializeToStream, setName, toString, validate
-
Field Details
-
CURRENT_VERSION
public static final int CURRENT_VERSION
Protobuf serialization version.
-
Method Details
-
deserializeFromProto
public static TensorFlowCheckpointModel<?> deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
Deserialization factory.
- Parameters:
version - The serialized object version.
className - The class name.
message - The serialized data.
- Returns:
- The deserialized object.
- Throws:
com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
-
isInitialized
public boolean isInitialized()
Is this model initialized?
- Returns:
- True if the model is ready to make predictions.
-
initialize
public final void initialize()
Initializes the model.
This call closes the old session (if it exists) and creates a fresh session from the current checkpoint path.
Throws TensorFlowException if it failed to read the checkpoint.
-
setCheckpointDirectory
public void setCheckpointDirectory(String newCheckpointDirectory)
Sets the checkpoint directory.
The model likely needs re-initializing with initialize() after this call.
- Parameters:
newCheckpointDirectory
- The new checkpoint directory.
-
getCheckpointDirectory
public String getCheckpointDirectory()
Gets the checkpoint directory this model loads from.
- Returns:
- The checkpoint directory.
-
setCheckpointName
public void setCheckpointName(String newCheckpointName)
Sets the checkpoint name.
The model likely needs re-initializing with initialize() after this call.
- Parameters:
newCheckpointName
- The new checkpoint name.
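To see why re-initialising matters after changing the checkpoint name, the following sketch mimics the documented behaviour of initialize() (close the old session if it exists, then create a fresh one from the current checkpoint path). Everything in it is a hypothetical stand-in: `SessionSwapSketch` and `FakeSession` are not Tribuo or TensorFlow classes, and joining the directory and name into a single path is an assumption made for illustration.

```java
import java.nio.file.Paths;

/** Hypothetical stand-in showing the close-then-recreate pattern on re-initialisation. */
final class SessionSwapSketch {

    /** Stand-in for a session bound to one checkpoint path. */
    static final class FakeSession implements AutoCloseable {
        final String checkpointPath;
        boolean closed;

        FakeSession(String checkpointPath) {
            this.checkpointPath = checkpointPath;
        }

        @Override
        public void close() {
            closed = true;
        }
    }

    private String checkpointDirectory;
    private String checkpointName;
    private FakeSession session;

    SessionSwapSketch(String checkpointDirectory, String checkpointName) {
        this.checkpointDirectory = checkpointDirectory;
        this.checkpointName = checkpointName;
    }

    /** Changing the name does not touch the live session; it still uses the old path. */
    void setCheckpointName(String newCheckpointName) {
        this.checkpointName = newCheckpointName;
    }

    /** Closes any existing session, then opens a new one on the current path. */
    FakeSession initialize() {
        if (session != null) {
            session.close();
        }
        // Assumption: the checkpoint path is the directory joined with the name.
        session = new FakeSession(Paths.get(checkpointDirectory, checkpointName).toString());
        return session;
    }

    public static void main(String[] args) {
        SessionSwapSketch model = new SessionSwapSketch("checkpoints", "model-a");
        FakeSession first = model.initialize();
        model.setCheckpointName("model-b");
        FakeSession second = model.initialize();
        System.out.println("first session closed: " + first.closed);
        System.out.println("second session path: " + second.checkpointPath);
    }
}
```

In the sketch, the session created before the rename keeps pointing at the old checkpoint until initialize() is called again, which is the situation the note above warns about.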
-
getCheckpointName
public String getCheckpointName()
Gets the checkpoint name this model loads from.
- Returns:
- The checkpoint name.
-
convertToNativeModel
public TensorFlowNativeModel<T> convertToNativeModel()
Creates a TensorFlowNativeModel version of this model.
- Returns:
- A version of this model using Tribuo's native serialization mechanism.
-
copy
protected TensorFlowCheckpointModel<T> copy(String newName, ModelProvenance newProvenance)
Description copied from class: Model
Copies a model, replacing its provenance and name with the supplied values.
Used to provide the provenance removal functionality.
-
serialize
public org.tribuo.protos.core.ModelProto serialize()
Description copied from interface: ProtoSerializable
Serializes this object to a protobuf.
-