Class TensorFlowNativeModel<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.tensorflow.TensorFlowModel<T>
org.tribuo.interop.tensorflow.TensorFlowNativeModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, AutoCloseable, ProtoSerializable<org.tribuo.protos.core.ModelProto>

public final class TensorFlowNativeModel<T extends Output<T>> extends TensorFlowModel<T>
This model encapsulates a TensorFlow model running in graph mode with a single tensor output.

It accepts a FeatureConverter that converts an example's features into a set of Tensors, and an OutputConverter that converts a Tensor into a Prediction.
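To make the two converter roles concrete, here is a minimal sketch of the feature-to-tensor-to-prediction pipeline in plain Java. It assumes hypothetical stand-in interfaces (FeatureConverterSketch, OutputConverterSketch) that use float[] in place of Tensor and a String label in place of Prediction. Tribuo's actual FeatureConverter and OutputConverter signatures differ, and the TensorFlow graph execution between the two stages is faked as the identity function.

```java
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins for Tribuo's FeatureConverter and OutputConverter: float[] plays the
// role of a Tensor and a String label plays the role of a Prediction. Not Tribuo code.
class ConverterPipelineSketch {

    // Converts an example's named features into a dense input vector.
    interface FeatureConverterSketch {
        float[] convert(Map<String, Double> features);
    }

    // Converts the model's output vector into a prediction.
    interface OutputConverterSketch<P> {
        P convert(float[] scores);
    }

    // Uses a fixed feature ordering; features absent from the example become 0.
    static FeatureConverterSketch denseConverter(List<String> featureOrder) {
        return features -> {
            float[] out = new float[featureOrder.size()];
            for (int i = 0; i < out.length; i++) {
                out[i] = features.getOrDefault(featureOrder.get(i), 0.0).floatValue();
            }
            return out;
        };
    }

    // Returns the label whose score is highest (argmax).
    static OutputConverterSketch<String> argmaxConverter(List<String> labels) {
        return scores -> {
            int best = 0;
            for (int i = 1; i < scores.length; i++) {
                if (scores[i] > scores[best]) {
                    best = i;
                }
            }
            return labels.get(best);
        };
    }

    public static void main(String[] args) {
        float[] input = denseConverter(List.of("a", "b", "c")).convert(Map.of("a", 1.0, "c", 3.0));
        // A real model would run the TensorFlow graph here; this sketch treats it as the identity.
        float[] scores = input;
        System.out.println(argmaxConverter(List.of("x", "y", "z")).convert(scores)); // prints z
    }
}
```

With features a=1.0 and c=3.0 and the ordering [a, b, c], the input vector is [1, 0, 3], so the argmax converter returns the third label.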

This model's serialized form stores the weights and is entirely self-contained. If you wish to convert it into a model which uses checkpoints, call convertToCheckpointModel(java.lang.String, java.lang.String).

The model's serialVersionUID is set to the major TensorFlow version number times 100.
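For TensorFlow 2.x this convention gives 2 * 100 = 200. The hypothetical class below (not part of Tribuo) declares its serialVersionUID that way, and its main method reads the value back through java.io.ObjectStreamClass to show what Java serialization will actually use.

```java
import java.io.ObjectStreamClass;
import java.io.Serializable;

// Hypothetical class illustrating the "major TensorFlow version * 100" convention; not Tribuo code.
class TensorFlow2ModelSketch implements Serializable {
    // Major TensorFlow version (2) times 100.
    private static final long serialVersionUID = 2 * 100L;

    public static void main(String[] args) {
        // Reads back the UID recorded in this class's serialization descriptor.
        long uid = ObjectStreamClass.lookup(TensorFlow2ModelSketch.class).getSerialVersionUID();
        System.out.println(uid); // prints 200
    }
}
```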

N.B. TensorFlow support is experimental and may change without a major version bump.

  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
  • Method Details

    • deserializeFromProto

      public static TensorFlowNativeModel<?> deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
      Throws:
      com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
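A deserialization factory of this shape typically validates the version before parsing anything. The following is a hedged sketch with a hypothetical class: the protobuf Any payload is replaced by a plain String so the example has no dependencies, CURRENT_VERSION is assumed to be 0, and the class-name check is an assumption of the sketch rather than documented behavior of deserializeFromProto.

```java
// Hypothetical sketch of a versioned deserialization factory; not Tribuo code.
// The real method unpacks a com.google.protobuf.Any; here the payload is a String.
class VersionedFactorySketch {
    // Assumed value for illustration only.
    static final int CURRENT_VERSION = 0;

    final String payload;

    private VersionedFactorySketch(String payload) {
        this.payload = payload;
    }

    static VersionedFactorySketch deserializeFromProto(int version, String className, String message) {
        // Reject negative versions and versions newer than this code understands.
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version
                    + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Reject data recorded for a different class (an assumption of this sketch).
        if (!VersionedFactorySketch.class.getName().equals(className)) {
            throw new IllegalArgumentException("Expected " + VersionedFactorySketch.class.getName()
                    + ", found " + className);
        }
        return new VersionedFactorySketch(message);
    }
}
```

Checking the version first means data written by a newer release fails fast with a clear message instead of being misread.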
    • copy

      protected TensorFlowNativeModel<T> copy(String newName, ModelProvenance newProvenance)
      Description copied from class: Model
      Copies a model, replacing its provenance and name with the supplied values.

      Used to provide the provenance removal functionality.

      Specified by:
      copy in class Model<T extends Output<T>>
      Parameters:
      newName - The new name.
      newProvenance - The new provenance.
      Returns:
      A copy of the model.
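Because copy takes a new name and provenance while keeping the learned parameters, provenance removal can be expressed as a copy with empty provenance. A minimal sketch, assuming a hypothetical immutable model whose provenance is a Map<String, String> rather than a ModelProvenance; the withProvenanceRemoved helper is illustrative and not part of Tribuo.

```java
import java.util.Map;

// Hypothetical immutable model illustrating copy(newName, newProvenance); not Tribuo code.
class CopyableModelSketch {
    final String name;
    final Map<String, String> provenance; // stands in for ModelProvenance
    final float[] weights;

    CopyableModelSketch(String name, Map<String, String> provenance, float[] weights) {
        this.name = name;
        this.provenance = Map.copyOf(provenance);
        this.weights = weights.clone(); // defensive copy so copies never share mutable state
    }

    // Same parameters, new name and provenance.
    CopyableModelSketch copy(String newName, Map<String, String> newProvenance) {
        return new CopyableModelSketch(newName, newProvenance, weights);
    }

    // Provenance removal expressed via copy: keep the name and weights, drop the metadata.
    CopyableModelSketch withProvenanceRemoved() {
        return copy(name, Map.of());
    }
}
```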
    • convertToCheckpointModel

      public TensorFlowCheckpointModel<T> convertToCheckpointModel(String checkpointDirectory, String checkpointName)
      Creates a TensorFlowCheckpointModel version of this model.
      Parameters:
      checkpointDirectory - The directory to write the checkpoint to.
      checkpointName - The name of the checkpoint files.
      Returns:
      A version of this model using a TensorFlow checkpoint to store the parameters.
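The difference between the two representations is where the parameters live: a TensorFlowNativeModel carries its weights in its serialized form, whereas a checkpoint model refers to files on disk. The sketch below illustrates that split with hypothetical classes (NativeModelSketch, CheckpointModelSketch) and a simple binary file; it does not reproduce TensorFlow's checkpoint format, and the method name merely mirrors convertToCheckpointModel.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical self-contained model: the weights are stored inside the object. Not Tribuo code.
class NativeModelSketch {
    final float[] weights;

    NativeModelSketch(float[] weights) {
        this.weights = weights.clone();
    }

    // Writes the weights to checkpointDirectory/checkpointName and returns a model that only
    // records that path. The file format here is a toy, not a TensorFlow checkpoint.
    CheckpointModelSketch convertToCheckpointModel(String checkpointDirectory, String checkpointName)
            throws IOException {
        Path dir = Paths.get(checkpointDirectory);
        Files.createDirectories(dir);
        Path file = dir.resolve(checkpointName);
        try (DataOutputStream out = new DataOutputStream(Files.newOutputStream(file))) {
            out.writeInt(weights.length);
            for (float w : weights) {
                out.writeFloat(w);
            }
        }
        return new CheckpointModelSketch(file);
    }
}

// Hypothetical checkpoint-backed model: holds only a reference to the weights on disk.
class CheckpointModelSketch {
    final Path checkpoint;

    CheckpointModelSketch(Path checkpoint) {
        this.checkpoint = checkpoint;
    }

    // Reads the weights back from the checkpoint file.
    float[] loadWeights() throws IOException {
        try (DataInputStream in = new DataInputStream(Files.newInputStream(checkpoint))) {
            float[] w = new float[in.readInt()];
            for (int i = 0; i < w.length; i++) {
                w[i] = in.readFloat();
            }
            return w;
        }
    }
}
```

The trade-off is that the checkpoint-backed model is small to serialize but only works while the checkpoint file remains available at the recorded path.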
    • serialize

      public org.tribuo.protos.core.ModelProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.protos.core.ModelProto>
      Overrides:
      serialize in class Model<T extends Output<T>>
      Returns:
      The protobuf.