Class TensorFlowNativeModel<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.tensorflow.TensorFlowModel<T>
org.tribuo.interop.tensorflow.TensorFlowNativeModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, AutoCloseable

public final class TensorFlowNativeModel<T extends Output<T>> extends TensorFlowModel<T>
This model encapsulates a TensorFlow model running in graph mode with a single tensor output.

It accepts a FeatureConverter that converts an example's features into a set of Tensors, and an OutputConverter that converts a Tensor into a Prediction.
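To make the two-stage flow concrete, here is a minimal plain-Java sketch of the same pattern. It does not use Tribuo or TensorFlow. `ToyFeatureConverter`, `ToyOutputConverter`, `ToyPipeline`, the dense feature layout, the argmax decoding and the fixed linear layer standing in for the graph are all hypothetical. They are chosen only to show features flowing into an input "tensor", through a graph with a single output, and back out as a prediction.

```java
import java.util.List;
import java.util.Map;

// HYPOTHETICAL stand-ins: these are NOT Tribuo's FeatureConverter/OutputConverter
// interfaces, only plain-Java analogues of the same two-stage pipeline.
interface ToyFeatureConverter {
    // Turns named feature values into a dense input "tensor" (here a float[]).
    float[] convert(Map<String, Double> features);
}

interface ToyOutputConverter<P> {
    // Turns the single output "tensor" of the graph into a prediction.
    P convert(float[] output);
}

final class ToyPipeline {
    // Dense converter: each known feature name gets a fixed index; unknown names are dropped.
    static ToyFeatureConverter dense(List<String> featureNames) {
        return features -> {
            float[] input = new float[featureNames.size()];
            for (int i = 0; i < featureNames.size(); i++) {
                input[i] = features.getOrDefault(featureNames.get(i), 0.0).floatValue();
            }
            return input;
        };
    }

    // Argmax converter: picks the label whose output score is highest.
    static ToyOutputConverter<String> argmax(List<String> labels) {
        return output -> {
            int best = 0;
            for (int i = 1; i < output.length; i++) {
                if (output[i] > output[best]) {
                    best = i;
                }
            }
            return labels.get(best);
        };
    }

    // Stand-in for running the graph: a fixed linear layer, out[j] = sum_i w[j][i] * in[i].
    static float[] runGraph(float[][] weights, float[] input) {
        float[] out = new float[weights.length];
        for (int j = 0; j < weights.length; j++) {
            for (int i = 0; i < input.length; i++) {
                out[j] += weights[j][i] * input[i];
            }
        }
        return out;
    }

    // Features -> input tensor -> graph -> output tensor -> prediction.
    static <P> P predict(ToyFeatureConverter fc, float[][] weights,
                         ToyOutputConverter<P> oc, Map<String, Double> features) {
        return oc.convert(runGraph(weights, fc.convert(features)));
    }
}
```

One consequence of separating the conversions from the graph is that the same graph-running code can be paired with different input encodings and output types. Tribuo's actual converter interfaces have their own method signatures, which this sketch does not attempt to reproduce.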

This model's serialized form stores the weights and is entirely self-contained. If you wish to convert it into a model which uses checkpoints, call convertToCheckpointModel(java.lang.String, java.lang.String).
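A rough illustration of what "self-contained" means for Java serialization, using a hypothetical `ToySelfContainedModel` rather than the Tribuo class: because the weights are ordinary fields of the object, a serialize/deserialize round trip restores them with no external files involved. This is a sketch of the idea under that simplifying assumption, not a description of how `TensorFlowNativeModel` actually encodes its TensorFlow graph and variables.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

// HYPOTHETICAL self-contained model: the weights live inside the object, so Java
// serialization captures them and nothing outside the byte stream is needed.
final class ToySelfContainedModel implements Serializable {
    private static final long serialVersionUID = 1L;
    final String name;
    final float[] weights;

    ToySelfContainedModel(String name, float[] weights) {
        this.name = name;
        this.weights = weights.clone();
    }

    // Writes the whole object graph (name + weights) into a byte array.
    static byte[] serialize(ToySelfContainedModel m) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(m);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Rebuilds the object, weights included, from the bytes alone.
    static ToySelfContainedModel deserialize(byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return (ToySelfContainedModel) in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```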

The model's serialVersionUID is set to the major TensorFlow version number times 100.
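Under that convention, a model built against TensorFlow 2.x would carry a serialVersionUID of 2 * 100 = 200. The sketch below declares a hypothetical `ToyTensorFlowBackedModel` following the same rule (the major version `2` is an assumption for illustration) and reads the UID back with `java.io.ObjectStreamClass`, the standard way to inspect the value Java serialization will use for a class.

```java
import java.io.ObjectStreamClass;
import java.io.Serializable;

// HYPOTHETICAL class mirroring the documented convention:
// serialVersionUID = major TensorFlow version * 100.
final class ToyTensorFlowBackedModel implements Serializable {
    static final int TF_MAJOR_VERSION = 2; // assumed major version, for illustration only
    private static final long serialVersionUID = TF_MAJOR_VERSION * 100L; // evaluates to 200L
}

final class SerialVersionCheck {
    // Returns the serialVersionUID that Java serialization will use for cls.
    static long uidOf(Class<?> cls) {
        return ObjectStreamClass.lookup(cls).getSerialVersionUID();
    }
}
```

Java serialization rejects a stream whose recorded serialVersionUID differs from that of the loaded class, throwing `java.io.InvalidClassException`. One effect of tying the UID to the TensorFlow major version is therefore that a model serialized under one major version fails fast, rather than loading silently, under another.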

N.B. TensorFlow support is experimental and may change without a major version bump.

  • Method Details

    • copy

      protected TensorFlowNativeModel<T> copy(String newName, ModelProvenance newProvenance)
      Description copied from class: Model
      Copies a model, replacing its provenance and name with the supplied values.

      Used to provide the provenance removal functionality.

      Specified by:
      copy in class Model<T extends Output<T>>
      Parameters:
      newName - The new name.
      newProvenance - The new provenance.
      Returns:
      A copy of the model.
    • convertToCheckpointModel

      public TensorFlowCheckpointModel<T> convertToCheckpointModel(String checkpointDirectory, String checkpointName)
      Creates a TensorFlowCheckpointModel version of this model.
      Parameters:
      checkpointDirectory - The directory to write the checkpoint to.
      checkpointName - The name of the checkpoint files.
      Returns:
      A version of this model using a TensorFlow checkpoint to store the parameters.
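The `copy` contract (same model, new name and provenance) can be sketched with a hypothetical immutable `ToyModel`, in which provenance is simplified to a string map instead of Tribuo's `ModelProvenance`. Copying with an empty provenance map mimics the provenance-removal use case while leaving the original instance untouched.

```java
import java.util.Map;
import java.util.Objects;

// HYPOTHETICAL immutable model. Provenance is modelled as a plain string map here;
// Tribuo's ModelProvenance is a much richer object.
final class ToyModel {
    final String name;
    final Map<String, String> provenance;
    final float[] weights;

    ToyModel(String name, Map<String, String> provenance, float[] weights) {
        this.name = Objects.requireNonNull(name);
        this.provenance = Map.copyOf(provenance); // defensive, unmodifiable copy
        this.weights = weights.clone();           // defensive copy of the parameters
    }

    // Returns a new instance carrying the supplied name and provenance;
    // the learned parameters are carried over (copied) and the original is unchanged.
    ToyModel copy(String newName, Map<String, String> newProvenance) {
        return new ToyModel(newName, newProvenance, weights);
    }
}
```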
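`convertToCheckpointModel(checkpointDirectory, checkpointName)` moves the parameters out of the serialized object and into checkpoint files. The following sketch imitates that split with a hypothetical `ToyCheckpointModel`: weights are written to `checkpointDirectory/checkpointName` as raw big-endian floats (not TensorFlow's checkpoint format), and the returned object keeps only the file location, reloading the weights on demand.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// HYPOTHETICAL checkpoint-backed model. The file layout is a raw float dump chosen
// for illustration; it is NOT TensorFlow's checkpoint format.
final class ToyCheckpointModel {
    final Path checkpointFile;

    private ToyCheckpointModel(Path checkpointFile) {
        this.checkpointFile = checkpointFile;
    }

    // Writes the weights to <checkpointDirectory>/<checkpointName> and returns a
    // model that stores only that location.
    static ToyCheckpointModel convert(float[] weights, String checkpointDirectory, String checkpointName) {
        Path dir = Paths.get(checkpointDirectory);
        Path file = dir.resolve(checkpointName);
        ByteBuffer buf = ByteBuffer.allocate(Float.BYTES * weights.length);
        for (float w : weights) {
            buf.putFloat(w); // big-endian by default
        }
        try {
            Files.createDirectories(dir);
            Files.write(file, buf.array());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return new ToyCheckpointModel(file);
    }

    // Reloads the weights from the checkpoint file.
    float[] loadWeights() {
        try {
            ByteBuffer buf = ByteBuffer.wrap(Files.readAllBytes(checkpointFile));
            float[] weights = new float[buf.remaining() / Float.BYTES];
            for (int i = 0; i < weights.length; i++) {
                weights[i] = buf.getFloat();
            }
            return weights;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

In the sketch, the converted object is small but no longer self-contained: the checkpoint file must be available wherever the model is used.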