Class TensorFlowCheckpointModel<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.tensorflow.TensorFlowModel<T>
org.tribuo.interop.tensorflow.TensorFlowCheckpointModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Closeable, Serializable, AutoCloseable, ProtoSerializable<org.tribuo.protos.core.ModelProto>

public final class TensorFlowCheckpointModel<T extends Output<T>> extends TensorFlowModel<T> implements Closeable
This model encapsulates a simple TensorFlow model that takes an input feed dict and produces a single output tensor.

If the checkpoint is not available on construction or after deserialisation, then the model is uninitialised. Models can be initialised by calling initialize() after calling setCheckpointDirectory(java.lang.String) and setCheckpointName(java.lang.String) with the correct directory and name respectively.
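The initialisation sequence above can be sketched as a small standalone class. Everything in it is hypothetical: `CheckpointLifecycleSketch` and its nested `FakeSession` are stand-ins rather than Tribuo types, and "the checkpoint is available" is approximated by the checkpoint directory existing on disk. The point it illustrates is that the setters only record a new location, while `initialize()` closes any old session and opens a fresh one.

```java
import java.nio.file.Files;
import java.nio.file.Paths;

public class CheckpointLifecycleSketch {
    /** Hypothetical stand-in for a TensorFlow session opened from one checkpoint. */
    static final class FakeSession implements AutoCloseable {
        final String checkpointPath;
        boolean closed = false;
        FakeSession(String checkpointPath) { this.checkpointPath = checkpointPath; }
        @Override public void close() { closed = true; }
    }

    private String checkpointDirectory;
    private String checkpointName;
    private FakeSession session; // null while the model is uninitialised

    public CheckpointLifecycleSketch(String checkpointDirectory, String checkpointName) {
        this.checkpointDirectory = checkpointDirectory;
        this.checkpointName = checkpointName;
        try {
            initialize();
        } catch (IllegalStateException e) {
            // Checkpoint unavailable on construction: leave the model uninitialised.
        }
    }

    // The setters only record the new location; they do not reload anything.
    public void setCheckpointDirectory(String newCheckpointDirectory) {
        this.checkpointDirectory = newCheckpointDirectory;
    }

    public void setCheckpointName(String newCheckpointName) {
        this.checkpointName = newCheckpointName;
    }

    public boolean isInitialized() {
        return session != null;
    }

    /** Closes the old session (if any), then opens a fresh one from the current checkpoint path. */
    public final void initialize() {
        if (session != null) {
            session.close();
            session = null;
        }
        // Sketch assumption: the checkpoint counts as readable if its directory exists.
        if (checkpointDirectory == null || checkpointName == null
                || !Files.isDirectory(Paths.get(checkpointDirectory))) {
            throw new IllegalStateException("Failed to read checkpoint "
                    + checkpointName + " in " + checkpointDirectory);
        }
        session = new FakeSession(Paths.get(checkpointDirectory, checkpointName).toString());
    }

    FakeSession currentSession() {
        return session;
    }

    public static void main(String[] args) {
        CheckpointLifecycleSketch model = new CheckpointLifecycleSketch("/no/such/checkpoint/dir", "model");
        System.out.println(model.isInitialized()); // false
        model.setCheckpointDirectory(System.getProperty("java.io.tmpdir"));
        System.out.println(model.isInitialized()); // false: the setter did not reload
        model.initialize();
        System.out.println(model.isInitialized()); // true
    }
}
```

In this sketch a failed `initialize()` leaves the object uninitialised; how the real class behaves after a failed read is not specified above.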

This model's serialized form does not contain the weights; they remain in the specified checkpoint directory. If you wish to convert it into a model which stores the weights inside the model, call convertToNativeModel().
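The difference between referencing weights by location and embedding them can be illustrated with plain Java serialization. This is only an analogy and uses neither of Tribuo's serialization paths; `PathBacked`, `Embedded` and `Embedded.from` are hypothetical names. The path-backed object keeps its weights in a `transient` field, so after a round trip it carries only the checkpoint location and has to be reloaded, whereas the embedded object brings its weights with it.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

public class SerializedFormSketch {
    /** Hypothetical analogue of a checkpoint-backed model: only the weights' location is serialized. */
    static final class PathBacked implements Serializable {
        private static final long serialVersionUID = 1L;
        final String checkpointDirectory;
        transient float[] weights; // skipped by serialization; must be reloaded from the directory

        PathBacked(String checkpointDirectory, float[] weights) {
            this.checkpointDirectory = checkpointDirectory;
            this.weights = weights;
        }
    }

    /** Hypothetical analogue of a native model: the weights are part of the serialized form. */
    static final class Embedded implements Serializable {
        private static final long serialVersionUID = 1L;
        final float[] weights;

        Embedded(float[] weights) {
            this.weights = weights;
        }

        // Rough analogue of convertToNativeModel(): copy the loaded weights into a self-contained object.
        static Embedded from(PathBacked model) {
            return new Embedded(model.weights.clone());
        }
    }

    /** Serializes and deserializes an object in memory. */
    @SuppressWarnings("unchecked")
    static <S extends Serializable> S roundTrip(S obj) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(obj);
            }
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (S) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        PathBacked checkpointModel = new PathBacked("/models/ckpt", new float[] {0.5f, -1.0f});
        System.out.println(roundTrip(checkpointModel).weights == null);                // true
        System.out.println(roundTrip(Embedded.from(checkpointModel)).weights.length); // 2
    }
}
```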

It accepts a FeatureConverter that converts an example's features into a TensorMap, and an OutputConverter that converts a Tensor into a Prediction.
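The flow from example to prediction can be sketched with two hypothetical single-method interfaces standing in for FeatureConverter and OutputConverter. Tribuo's real interfaces work on TensorMap, Tensor and Prediction and have more methods; here plain maps and `float[]` arrays are used instead, and a toy function plays the role of the TensorFlow graph, receiving the feed dict and returning the single output tensor.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class ConverterPipelineSketch {
    /** Hypothetical stand-in for FeatureConverter: example features -> named input tensors (the feed dict). */
    interface FeatureConverterSketch {
        Map<String, float[]> convert(Map<String, Double> features);
    }

    /** Hypothetical stand-in for OutputConverter: the single output tensor -> a prediction. */
    interface OutputConverterSketch<P> {
        P convert(float[] outputTensor);
    }

    static <P> P predict(Map<String, Double> features,
                         FeatureConverterSketch featureConverter,
                         Function<Map<String, float[]>, float[]> graph,
                         OutputConverterSketch<P> outputConverter) {
        Map<String, float[]> feedDict = featureConverter.convert(features); // features -> feed dict
        float[] output = graph.apply(feedDict);                             // run graph, one output tensor
        return outputConverter.convert(output);                             // tensor -> prediction
    }

    /** Dense encoding: each known feature gets a fixed index; absent features are 0. */
    static FeatureConverterSketch dense(String inputName, List<String> featureOrder) {
        return features -> {
            float[] x = new float[featureOrder.size()];
            for (int i = 0; i < x.length; i++) {
                x[i] = features.getOrDefault(featureOrder.get(i), 0.0).floatValue();
            }
            return Map.of(inputName, x);
        };
    }

    /** Argmax decoding: returns the label with the highest score. */
    static OutputConverterSketch<String> argmax(List<String> labels) {
        return scores -> {
            int best = 0;
            for (int i = 1; i < scores.length; i++) {
                if (scores[i] > scores[best]) {
                    best = i;
                }
            }
            return labels.get(best);
        };
    }

    public static void main(String[] args) {
        // Toy graph: the input vector is used directly as two class scores.
        Function<Map<String, float[]>, float[]> graph = feed -> feed.get("input").clone();
        String label = predict(Map.of("b", 2.0), dense("input", List.of("a", "b")),
                graph, argmax(List.of("neg", "pos")));
        System.out.println(label); // pos
    }
}
```

Keeping the two converters separate from the graph means the same graph can be reused with a different input encoding or output type by swapping only a converter.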

The model's serialVersionUID is set to the major TensorFlow version number times 100.
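As a worked instance of the rule above: TensorFlow 2.x gives 2 × 100 = 200, and a 1.x release would give 100. The helper below is hypothetical and only spells out the arithmetic; in a real class the value would simply be written as a literal constant. Which TensorFlow release a given Tribuo version binds to is not stated here, so the version strings are examples.

```java
import java.io.Serializable;

public class SerialVersionSketch implements Serializable {
    // For TensorFlow 2.x the "major version times 100" rule gives this constant.
    private static final long serialVersionUID = 200L;

    /** Hypothetical helper applying the rule to a version string such as "2.7.4". */
    static long serialVersionUidFor(String tensorFlowVersion) {
        int major = Integer.parseInt(tensorFlowVersion.split("\\.")[0]);
        return major * 100L;
    }

    public static void main(String[] args) {
        System.out.println(serialVersionUidFor("2.7.4"));                      // 200
        System.out.println(serialVersionUidFor("2.7.4") == serialVersionUID); // true
    }
}
```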

N.B. TensorFlow support is experimental and may change without a major version bump.

  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
  • Method Details

    • deserializeFromProto

      public static TensorFlowCheckpointModel<?> deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
      Throws:
      com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
    • isInitialized

      public boolean isInitialized()
      Is this model initialized?
      Returns:
      True if the model is ready to make predictions.
    • initialize

      public final void initialize()
      Initializes the model.

      This call closes the old session (if it exists) and creates a fresh session from the current checkpoint path.

      Throws TensorFlowException if the checkpoint cannot be read.

    • setCheckpointDirectory

      public void setCheckpointDirectory(String newCheckpointDirectory)
      Sets the checkpoint directory.

      The model is not reloaded by this call; call initialize() afterwards so that it loads from the new directory.

      Parameters:
      newCheckpointDirectory - The new checkpoint directory.
    • getCheckpointDirectory

      public String getCheckpointDirectory()
      Gets the checkpoint directory this model loads from.
      Returns:
      The checkpoint directory.
    • setCheckpointName

      public void setCheckpointName(String newCheckpointName)
      Sets the checkpoint name.

      The model is not reloaded by this call; call initialize() afterwards so that it loads the new checkpoint.

      Parameters:
      newCheckpointName - The new checkpoint name.
    • getCheckpointName

      public String getCheckpointName()
      Gets the checkpoint name this model loads from.
      Returns:
      The checkpoint name.
    • convertToNativeModel

      public TensorFlowNativeModel<T> convertToNativeModel()
      Creates a TensorFlowNativeModel version of this model.
      Returns:
      A version of this model using Tribuo's native serialization mechanism.
    • copy

      protected TensorFlowCheckpointModel<T> copy(String newName, ModelProvenance newProvenance)
      Description copied from class: Model
      Copies a model, replacing its provenance and name with the supplied values.

      Used to provide the provenance removal functionality.

      Specified by:
      copy in class Model<T extends Output<T>>
      Parameters:
      newName - The new name.
      newProvenance - The new provenance.
      Returns:
      A copy of the model.
    • serialize

      public org.tribuo.protos.core.ModelProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<T extends Output<T>>
      Overrides:
      serialize in class Model<T extends Output<T>>
      Returns:
      The protobuf.
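A deserialization factory such as deserializeFromProto(int, String, Any) typically checks the version before decoding the message. The sketch below assumes that versions above CURRENT_VERSION are rejected, and it uses a `Map<String, String>` in place of the protobuf `Any` message; the class name, map keys, exception type and version value are all hypothetical. The restored object only records the checkpoint location, which matches the class description: whether it can make predictions still depends on the checkpoint being present.

```java
import java.util.Map;

public class VersionedFactorySketch {
    /** Hypothetical version number; the real CURRENT_VERSION value is not given above. */
    static final int CURRENT_VERSION = 0;

    final String checkpointDirectory;
    final String checkpointName;

    private VersionedFactorySketch(String checkpointDirectory, String checkpointName) {
        this.checkpointDirectory = checkpointDirectory;
        this.checkpointName = checkpointName;
    }

    /** Hypothetical analogue of deserializeFromProto, with a Map standing in for the protobuf message. */
    static VersionedFactorySketch deserialize(int version, String className, Map<String, String> message) {
        // Reject versions this code does not know how to read.
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version
                    + ", this class supports at most version " + CURRENT_VERSION);
        }
        String dir = message.get("checkpoint_directory");
        String name = message.get("checkpoint_name");
        if (dir == null || name == null) {
            // Plays the role of InvalidProtocolBufferException in this sketch.
            throw new IllegalArgumentException("Could not parse message for " + className);
        }
        // Only the location is restored; loading the checkpoint is a separate step.
        return new VersionedFactorySketch(dir, name);
    }

    public static void main(String[] args) {
        VersionedFactorySketch m = deserialize(0, VersionedFactorySketch.class.getName(),
                Map.of("checkpoint_directory", "/tmp/ckpt", "checkpoint_name", "model"));
        System.out.println(m.checkpointDirectory + "/" + m.checkpointName); // /tmp/ckpt/model
    }
}
```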