Class TensorFlowModel<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.tensorflow.TensorFlowModel<T>
Type Parameters:
T - The output type.
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, AutoCloseable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
Direct Known Subclasses:
TensorFlowCheckpointModel, TensorFlowNativeModel

public abstract class TensorFlowModel<T extends Output<T>> extends Model<T> implements AutoCloseable
Base class for a TensorFlow model that operates on Examples.

The subclasses are package-private and concern themselves with how the model is stored on disk.

N.B. TensorFlow support is experimental and may change without a major version bump.

  • Field Details

    • batchSize

      protected int batchSize
    • outputName

      protected final String outputName
    • featureConverter

      protected final FeatureConverter featureConverter
    • outputConverter

      protected final OutputConverter<T> outputConverter
    • modelGraph

      protected transient org.tensorflow.Graph modelGraph
    • session

      protected transient org.tensorflow.Session session
    • closed

      protected transient boolean closed
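      The modelGraph, session and closed fields are transient, so standard Java serialization skips them and a deserialized instance sees their default values (null and false). A minimal sketch of that JDK behavior, using a hypothetical stand-in class (TransientStateSketch) with plain Object fields in place of org.tensorflow.Graph and org.tensorflow.Session; how Tribuo itself rebuilds the native state after deserialization is not shown here.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical stand-in for a model that mixes serializable and native state.
class TransientStateSketch implements Serializable {
    private static final long serialVersionUID = 1L;

    final String outputName;      // written by Java serialization
    transient Object modelGraph;  // stand-in for org.tensorflow.Graph, skipped
    transient Object session;     // stand-in for org.tensorflow.Session, skipped

    TransientStateSketch(String outputName) {
        this.outputName = outputName;
        this.modelGraph = new Object();
        this.session = new Object();
    }

    // Serializes and deserializes the object with standard Java serialization.
    static TransientStateSketch roundTrip(TransientStateSketch in) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(in);
            }
            try (ObjectInputStream objIn =
                     new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (TransientStateSketch) objIn.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Serialization round trip failed", e);
        }
    }
}
```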
  • Constructor Details

    • TensorFlowModel

      protected TensorFlowModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, org.tensorflow.proto.framework.GraphDef trainedGraphDef, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter)
      Builds a TensorFlowModel. The session should be initialized in the subclass constructor.
      Parameters:
      name - The model name.
      provenance - The model provenance.
      featureIDMap - The feature domain.
      outputIDInfo - The output domain.
      trainedGraphDef - The graph definition.
      batchSize - The test-time batch size.
      outputName - The name of the output operation.
      featureConverter - The feature converter.
      outputConverter - The output converter.
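      A minimal sketch of the constructor contract described above, where the base class stores the shared state and the subclass constructor creates the session. SketchBaseModel, SketchNativeModel and FakeSession are hypothetical stand-ins, not Tribuo or TensorFlow classes.

```java
// Hypothetical stand-in for org.tensorflow.Session.
class FakeSession implements AutoCloseable {
    boolean open = true;

    @Override
    public void close() {
        open = false;
    }
}

// Hypothetical base class: stores shared state but leaves the session to subclasses.
abstract class SketchBaseModel {
    protected final String outputName;
    protected final int batchSize;
    protected FakeSession session; // set by the subclass constructor

    protected SketchBaseModel(String outputName, int batchSize) {
        this.outputName = outputName;
        this.batchSize = batchSize;
    }
}

// Hypothetical subclass: decides how the session is created for its storage format.
class SketchNativeModel extends SketchBaseModel {
    SketchNativeModel(String outputName, int batchSize) {
        super(outputName, batchSize);
        this.session = new FakeSession(); // the subclass constructor opens the session
    }
}
```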
  • Method Details

    • predict

      public Prediction<T> predict(Example<T> example)
      Description copied from class: Model
      Uses the model to predict the output for a single example.

      predict does not mutate the example.

      Throws IllegalArgumentException if the example has no features or no feature overlap with the model.

      Specified by:
      predict in class Model<T extends Output<T>>
      Parameters:
      example - the example to predict.
      Returns:
      the result of the prediction.
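      The precondition described for predict (the example must have at least one feature, and at least one feature the model knows about) can be sketched with plain collections. FeatureCheckSketch and checkOverlap are hypothetical names, and the model's feature domain is modeled as a Set<String> of feature names rather than an ImmutableFeatureMap.

```java
import java.util.List;
import java.util.Set;

// Hypothetical sketch of the precondition described for predict.
final class FeatureCheckSketch {
    private FeatureCheckSketch() {}

    // Returns how many example features are in the model's domain,
    // or throws IllegalArgumentException if there is nothing usable.
    static int checkOverlap(Set<String> modelFeatures, List<String> exampleFeatures) {
        if (exampleFeatures.isEmpty()) {
            throw new IllegalArgumentException("Example has no features.");
        }
        int overlap = 0;
        for (String name : exampleFeatures) {
            if (modelFeatures.contains(name)) {
                overlap++;
            }
        }
        if (overlap == 0) {
            throw new IllegalArgumentException("Example has no feature overlap with the model.");
        }
        return overlap;
    }
}
```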
    • innerPredict

      protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples)
      Description copied from class: Model
      Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
      Overrides:
      innerPredict in class Model<T extends Output<T>>
      Parameters:
      examples - The examples to predict.
      Returns:
      The results of the predictions, in the same order as the examples.
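      Since the model carries a test-time batchSize, a plausible shape for batched prediction is to group the examples into chunks of at most batchSize, predict each chunk, and concatenate the results so they stay in input order. This is a sketch under that assumption: BatchingSketch and the predictBatch function are hypothetical, and any tensor conversion or session call the real method performs is not reproduced.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch: run a batch-level predictor over an Iterable in chunks of batchSize.
final class BatchingSketch {
    private BatchingSketch() {}

    static <E, P> List<P> predictInBatches(Iterable<E> examples, int batchSize,
                                           Function<List<E>, List<P>> predictBatch) {
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be positive, found " + batchSize);
        }
        List<P> results = new ArrayList<>();
        List<E> batch = new ArrayList<>(batchSize);
        for (E example : examples) {
            batch.add(example);
            if (batch.size() == batchSize) {
                results.addAll(predictBatch.apply(batch));
                // Fresh list, in case predictBatch keeps a reference to the old one.
                batch = new ArrayList<>(batchSize);
            }
        }
        if (!batch.isEmpty()) {
            // Final partial batch.
            results.addAll(predictBatch.apply(batch));
        }
        return results;
    }
}
```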
    • getBatchSize

      public int getBatchSize()
      Gets the current testing batch size.
      Returns:
      The batch size.
    • setBatchSize

      public void setBatchSize(int batchSize)
      Sets a new batch size.

      Throws IllegalArgumentException if the batch size isn't positive.

      Parameters:
      batchSize - The batch size to use.
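      The getBatchSize/setBatchSize contract (reject non-positive values, otherwise store the new size) as a minimal sketch. BatchSizeSketch is a hypothetical stand-in, and keeping the previous value when a new one is rejected is an assumption of this sketch.

```java
// Hypothetical sketch of the batch size getter/setter contract.
final class BatchSizeSketch {
    private int batchSize;

    BatchSizeSketch(int batchSize) {
        setBatchSize(batchSize);
    }

    int getBatchSize() {
        return batchSize;
    }

    // Stores the new size only if it is positive; the old value survives a rejection.
    void setBatchSize(int batchSize) {
        if (batchSize > 0) {
            this.batchSize = batchSize;
        } else {
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
    }
}
```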
    • getTopFeatures

      public Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
      Deep learning models don't provide feature rankings; use an Explainer instead.

      This method always returns the empty map.

      Specified by:
      getTopFeatures in class Model<T extends Output<T>>
      Parameters:
      n - the number of features to return.
      Returns:
      The empty map.
    • getExcuse

      public Optional<Excuse<T>> getExcuse(Example<T> example)
      Deep learning models don't provide excuses; use an Explainer instead.

      This method always returns Optional.empty().

      Specified by:
      getExcuse in class Model<T extends Output<T>>
      Parameters:
      example - The input example.
      Returns:
      Optional.empty().
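      Since getTopFeatures always returns an empty map and getExcuse always returns Optional.empty() for these models, code that handles arbitrary Model instances may want to treat empty results as "not supported" and fall back to an Explainer. A small caller-side sketch: ExplanationFallbackSketch.describe is hypothetical, Map.Entry stands in for OLCUT's Pair, and Optional<String> stands in for Optional<Excuse<T>>.

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Hypothetical caller-side helper for models that may not support rankings or excuses.
final class ExplanationFallbackSketch {
    private ExplanationFallbackSketch() {}

    // topFeatures mirrors the shape of getTopFeatures (output name -> ranked features).
    static String describe(Map<String, List<Map.Entry<String, Double>>> topFeatures,
                           Optional<String> excuse) {
        if (topFeatures.isEmpty() && !excuse.isPresent()) {
            return "No built-in explanation; use an Explainer.";
        }
        return excuse.orElse("Top features available for " + topFeatures.size() + " output(s).");
    }
}
```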
    • getOutputName

      public String getOutputName()
      Gets the name of the output operation.
      Returns:
      The output operation name.
    • exportModel

      public void exportModel(String path) throws IOException
      Exports this model as a SavedModelBundle, writing to the supplied directory.
      Parameters:
      path - The directory to export to.
      Throws:
      IOException - If writing to the directory fails.
    • close

      public void close()
      Specified by:
      close in interface AutoCloseable
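      Because the class implements AutoCloseable and owns a native session and graph, try-with-resources is a natural way to make sure they are released. The sketch uses a hypothetical stand-in (ClosableModelSketch): its close() releases the session before the graph, sets a closed flag, and is safe to call twice, while predict rejects calls after closing. Whether TensorFlowModel behaves exactly this way is an assumption; this page only states that close() implements AutoCloseable.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a model that owns native resources.
class ClosableModelSketch implements AutoCloseable {
    // Records release order so the behavior can be checked.
    final List<String> released = new ArrayList<>();
    private boolean closed = false;

    String predict(String example) {
        if (closed) {
            throw new IllegalStateException("Model has been closed.");
        }
        return "prediction for " + example;
    }

    @Override
    public void close() {
        if (!closed) {
            released.add("session"); // stand-in for closing the session
            released.add("graph");   // stand-in for closing the graph
            closed = true;
        }
    }

    boolean isClosed() {
        return closed;
    }
}
```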