Class TensorFlowFrozenExternalModel<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.ExternalModel<T,TensorMap,org.tensorflow.Tensor>
org.tribuo.interop.tensorflow.TensorFlowFrozenExternalModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Closeable, Serializable, AutoCloseable, ProtoSerializable<org.tribuo.protos.core.ModelProto>

public final class TensorFlowFrozenExternalModel<T extends Output<T>> extends ExternalModel<T,TensorMap,org.tensorflow.Tensor> implements Closeable
A Tribuo wrapper around a TensorFlow frozen model.

The model's serialVersionUID is set to the major TensorFlow version number times 100.
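As a worked instance of this rule, here is a minimal sketch. The helper `serialVersionFor` and the version strings are hypothetical and only illustrate the arithmetic: a TensorFlow 2.x build gives 200, a 1.x build gives 100.

```java
// Hypothetical helper illustrating the documented rule
// serialVersionUID = (major TensorFlow version) * 100.
final class SerialVersionRule {
    private SerialVersionRule() {}

    /** Takes the major component of a version string such as "2.7.0" and multiplies it by 100. */
    static long serialVersionFor(String tensorFlowVersion) {
        int dot = tensorFlowVersion.indexOf('.');
        String major = dot < 0 ? tensorFlowVersion : tensorFlowVersion.substring(0, dot);
        return Long.parseLong(major.trim()) * 100L;
    }

    public static void main(String[] args) {
        System.out.println(serialVersionFor("2.7.0")); // prints 200
    }
}
```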

N.B. TensorFlow support is experimental and may change without a major version bump.

  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
  • Method Details

    • deserializeFromProto

      public static TensorFlowFrozenExternalModel<?> deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
      Throws:
      com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
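The `version` parameter lets a deserialization factory refuse data written in a newer format than it understands. The sketch below shows one common shape for such a gate. `VersionGate`, the assumed `CURRENT_VERSION` value of 0, and the use of `IllegalArgumentException` are illustrative assumptions; only the `InvalidProtocolBufferException` for unparseable messages is documented above.

```java
// Hypothetical sketch of a version check run before parsing a protobuf payload.
// CURRENT_VERSION = 0 is an assumed value; the exception type and message are
// illustrative, not taken from Tribuo.
final class VersionGate {
    static final int CURRENT_VERSION = 0;

    private VersionGate() {}

    static void checkVersion(int version) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version
                    + ", this class supports at most version " + CURRENT_VERSION);
        }
    }

    public static void main(String[] args) {
        checkVersion(0); // accepted
        try {
            checkVersion(1); // newer than this code understands
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```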
    • convertFeatures

      protected TensorMap convertFeatures(SparseVector input)
      Description copied from class: ExternalModel
      Converts from a SparseVector using the external model's indices into the ingestion format for the external model.
      Specified by:
      convertFeatures in class ExternalModel<T extends Output<T>,TensorMap,org.tensorflow.Tensor>
      Parameters:
      input - The features using external indices.
      Returns:
      The ingestion format for the external model.
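To illustrate what an "ingestion format" can look like for a dense input, the sketch below writes one example's sparse features into a `[1][numFeatures]` matrix, leaving absent features at zero. Plain arrays stand in for `SparseVector`, `TensorMap` and `org.tensorflow.Tensor`. The dense layout is an assumption: the actual format is decided by the configured `FeatureConverter`.

```java
import java.util.Arrays;

// Hypothetical sketch: densify one example's sparse features (external indices)
// into a [1][numFeatures] matrix, i.e. a batch of one.
final class DenseRow {
    private DenseRow() {}

    static float[][] toDenseBatchOfOne(int[] indices, double[] values, int numFeatures) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values differ in length");
        }
        float[][] out = new float[1][numFeatures]; // zero-initialised
        for (int i = 0; i < indices.length; i++) {
            int idx = indices[i];
            if (idx < 0 || idx >= numFeatures) {
                throw new IllegalArgumentException("Index out of range: " + idx);
            }
            out[0][idx] = (float) values[i]; // absent features stay 0.0f
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] row = toDenseBatchOfOne(new int[]{0, 3}, new double[]{1.5, -2.0}, 5);
        System.out.println(Arrays.toString(row[0])); // prints [1.5, 0.0, 0.0, -2.0, 0.0]
    }
}
```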
    • convertFeaturesList

      protected TensorMap convertFeaturesList(List<SparseVector> input)
      Description copied from class: ExternalModel
      Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model.
      Specified by:
      convertFeaturesList in class ExternalModel<T extends Output<T>,TensorMap,org.tensorflow.Tensor>
      Parameters:
      input - The features using external indices.
      Returns:
      The ingestion format for the external model.
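The batched variant follows the same idea with one row per example. This is again a hypothetical sketch: `SparseRow` stands in for `SparseVector`, and a real converter would produce a `TensorMap` rather than a `float[][]`.

```java
import java.util.List;

// Hypothetical sketch: densify a batch of sparse examples into a
// [batchSize][numFeatures] matrix, one row per example.
final class DenseBatch {
    /** Stand-in for SparseVector: parallel arrays of external indices and values. */
    static final class SparseRow {
        final int[] indices;
        final double[] values;

        SparseRow(int[] indices, double[] values) {
            if (indices.length != values.length) {
                throw new IllegalArgumentException("indices and values differ in length");
            }
            this.indices = indices;
            this.values = values;
        }
    }

    private DenseBatch() {}

    static float[][] toDense(List<SparseRow> batch, int numFeatures) {
        float[][] out = new float[batch.size()][numFeatures]; // zero-initialised
        for (int r = 0; r < batch.size(); r++) {
            SparseRow row = batch.get(r);
            for (int i = 0; i < row.indices.length; i++) {
                out[r][row.indices[i]] = (float) row.values[i];
            }
        }
        return out;
    }
}
```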
    • externalPrediction

      protected org.tensorflow.Tensor externalPrediction(TensorMap input)
      Runs the session to make a prediction.

      Closes the input tensor after the prediction has been made.

      Specified by:
      externalPrediction in class ExternalModel<T extends Output<T>,TensorMap,org.tensorflow.Tensor>
      Parameters:
      input - The input in the external model's format.
      Returns:
      A tensor representing the output.
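The "closes the input tensor" contract can be pictured with try-with-resources: the input is released as soon as the prediction exists, while the output stays open for `convertOutput` to consume and close. In this hypothetical sketch `FakeTensor` stands in for both `TensorMap` and `org.tensorflow.Tensor`, and the doubling loop is a placeholder for running the TensorFlow session.

```java
// Hypothetical stand-in for a native tensor; only records whether close() ran.
final class FakeTensor implements AutoCloseable {
    final float[] data;
    private boolean closed = false;

    FakeTensor(float[] data) {
        this.data = data;
    }

    boolean isClosed() {
        return closed;
    }

    @Override
    public void close() {
        closed = true;
    }
}

// Mirrors the documented contract: the input is closed once the output exists,
// and the output is returned open for the caller (convertOutput) to close.
final class SessionSketch {
    private SessionSketch() {}

    static FakeTensor predict(FakeTensor input) {
        try (FakeTensor in = input) {
            float[] result = new float[in.data.length];
            for (int i = 0; i < result.length; i++) {
                result[i] = in.data[i] * 2f; // placeholder for running the session
            }
            return new FakeTensor(result);
        }
    }
}
```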
    • convertOutput

      protected Prediction<T> convertOutput(org.tensorflow.Tensor output, int numValidFeatures, Example<T> example)
      Converts a tensor into a prediction. Closes the output tensor after it's been converted.
      Specified by:
      convertOutput in class ExternalModel<T extends Output<T>,TensorMap,org.tensorflow.Tensor>
      Parameters:
      output - The output of the external model.
      numValidFeatures - The number of valid features in the input.
      example - The input example, used to construct the Prediction.
      Returns:
      A Prediction representing this tensor output.
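For a classifier, one plausible conversion is an argmax over the output scores, mapped back through the inverse of the `outputMapping` passed to `createTensorflowModel`. This hypothetical sketch illustrates only that idea. The real conversion is presumably handled by the configured `OutputConverter`, and `ArgmaxConverter` and `SimplePrediction` are stand-ins that use plain `String` labels.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: turn a score vector into a single prediction by taking the
// highest-scoring index. Strings stand in for T, SimplePrediction for Prediction<T>.
final class ArgmaxConverter {
    static final class SimplePrediction {
        final String label;
        final float score;
        final int numValidFeatures;

        SimplePrediction(String label, float score, int numValidFeatures) {
            this.label = label;
            this.score = score;
            this.numValidFeatures = numValidFeatures;
        }
    }

    private ArgmaxConverter() {}

    /** indexToLabel is the inverse of an output mapping (label -> TF id). */
    static SimplePrediction convert(float[] scores, Map<Integer, String> indexToLabel, int numValidFeatures) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("empty score vector");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i; // ties keep the earliest index
            }
        }
        return new SimplePrediction(indexToLabel.get(best), scores[best], numValidFeatures);
    }

    /** Inverts a label -> id mapping so ids can be turned back into labels. */
    static Map<Integer, String> invert(Map<String, Integer> outputMapping) {
        Map<Integer, String> inverse = new HashMap<>();
        for (Map.Entry<String, Integer> e : outputMapping.entrySet()) {
            inverse.put(e.getValue(), e.getKey());
        }
        return inverse;
    }
}
```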
    • convertOutput

      protected List<Prediction<T>> convertOutput(org.tensorflow.Tensor output, int[] numValidFeatures, List<Example<T>> examples)
      Converts a tensor into a prediction. Closes the output tensor after it's been converted.
      Specified by:
      convertOutput in class ExternalModel<T extends Output<T>,TensorMap,org.tensorflow.Tensor>
      Parameters:
      output - The output of the external model.
      numValidFeatures - An array with the number of valid features in each example.
      examples - The input examples, used to construct the Predictions.
      Returns:
      A list of Predictions representing this tensor output.
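In the batched case, row `r` of the output scores corresponds to `examples.get(r)` and is paired with `numValidFeatures[r]`. A hypothetical sketch under the same assumptions as a single-example argmax: `String` labels stand in for `T`, and each result is a `(label, numValidFeatures)` entry rather than a real `Prediction`.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the batched conversion: one argmax per row of a
// [batch][classes] score matrix, paired with that row's valid-feature count.
final class BatchArgmaxConverter {
    private BatchArgmaxConverter() {}

    static List<Map.Entry<String, Integer>> convert(float[][] scores,
                                                    Map<Integer, String> indexToLabel,
                                                    int[] numValidFeatures) {
        if (scores.length != numValidFeatures.length) {
            throw new IllegalArgumentException("expected one numValidFeatures entry per row");
        }
        List<Map.Entry<String, Integer>> predictions = new ArrayList<>(scores.length);
        for (int r = 0; r < scores.length; r++) {
            int best = 0;
            for (int c = 1; c < scores[r].length; c++) {
                if (scores[r][c] > scores[r][best]) {
                    best = c; // ties keep the earliest index
                }
            }
            predictions.add(new AbstractMap.SimpleImmutableEntry<>(indexToLabel.get(best), numValidFeatures[r]));
        }
        return predictions;
    }
}
```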
    • getTopFeatures

      public Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
      Description copied from class: Model
      Gets the top n features associated with this model.

      If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.

      If the model cannot describe its top features, then it returns Collections.emptyMap().

      Specified by:
      getTopFeatures in class Model<T extends Output<T>>
      Parameters:
      n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
      Returns:
      a map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the model
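The copied contract (a single `Model.ALL_OUTPUTS` key for models without per-output lists, and all features when `n` is negative) can be sketched over a plain weight map. This is a hypothetical illustration. The `ALL_OUTPUTS` string value, the descending-weight ranking, and `Map.Entry` in place of OLCUT's `Pair` are assumptions. A wrapped TensorFlow graph may not be able to describe its top features at all, so callers should also handle `Collections.emptyMap()`.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the getTopFeatures contract for a model without
// per-output feature lists. ALL_OUTPUTS mirrors Model.ALL_OUTPUTS by name only.
final class TopFeaturesSketch {
    static final String ALL_OUTPUTS = "ALL_OUTPUTS";

    private TopFeaturesSketch() {}

    static Map<String, List<Map.Entry<String, Double>>> topFeatures(Map<String, Double> weights, int n) {
        List<Map.Entry<String, Double>> sorted = new ArrayList<>(weights.entrySet());
        // Rank by descending weight (assumed ordering).
        sorted.sort(Map.Entry.<String, Double>comparingByValue().reversed());
        // A negative n means "return every feature".
        List<Map.Entry<String, Double>> kept =
                n < 0 ? sorted : new ArrayList<>(sorted.subList(0, Math.min(n, sorted.size())));
        return Collections.singletonMap(ALL_OUTPUTS, kept);
    }
}
```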
    • copy

      protected Model<T> copy(String newName, ModelProvenance newProvenance)
      Description copied from class: Model
      Copies a model, replacing its provenance and name with the supplied values.

      Used to provide the provenance removal functionality.

      Specified by:
      copy in class Model<T extends Output<T>>
      Parameters:
      newName - The new name.
      newProvenance - The new provenance.
      Returns:
      A copy of the model.
    • close

      public void close()
      Specified by:
      close in interface AutoCloseable
      Specified by:
      close in interface Closeable
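Since the class implements `Closeable` and, per `externalPrediction`, runs a TensorFlow session, callers would normally scope an instance with try-with-resources so its native resources are released even if prediction throws. A hypothetical sketch with a stand-in class, because constructing the real model requires a frozen graph file:

```java
import java.io.Closeable;

// Hypothetical sketch: scoping a Closeable model with try-with-resources.
// StandInModel is a placeholder, not a Tribuo class.
final class LifecycleSketch {
    static final class StandInModel implements Closeable {
        private boolean closed = false;

        String predict(String input) {
            if (closed) {
                throw new IllegalStateException("model already closed");
            }
            return "prediction for " + input;
        }

        @Override
        public void close() {
            closed = true; // assumption: a real wrapper would release its TensorFlow session/graph here
        }

        boolean isClosed() {
            return closed;
        }
    }

    private LifecycleSketch() {}

    static String predictOnce(StandInModel model, String input) {
        try (StandInModel m = model) {
            return m.predict(input);
        }
    }
}
```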
    • serialize

      public org.tribuo.protos.core.ModelProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.protos.core.ModelProto>
      Overrides:
      serialize in class Model<T extends Output<T>>
      Returns:
      The protobuf.
    • createTensorflowModel

      public static <T extends Output<T>> TensorFlowFrozenExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String filename)
      Creates a TensorFlowFrozenExternalModel by loading in a frozen graph.
      Type Parameters:
      T - The type of the output.
      Parameters:
      factory - The output factory.
      featureMapping - The feature mapping between Tribuo's names and the TF integer ids.
      outputMapping - The output mapping between Tribuo's outputs and the TF integer ids.
      outputName - The name of the output tensor.
      featureConverter - The feature transformation function.
      outputConverter - The output transformation function.
      filename - The filename to load the graph from.
      Returns:
      The TF model wrapped in a Tribuo ExternalModel.
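Both mapping arguments tie Tribuo-side names to the integer positions used inside the graph. Assuming those ids should run contiguously from 0 and follow the column or class order the graph was trained with (an assumption, not stated above), a hypothetical helper for building them might look like this. The feature map would be passed as `featureMapping`; for `outputMapping`, real output objects of type `T` would replace the strings used here.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: assigns consecutive TF ids 0..n-1 in the order the graph
// expects and rejects duplicate names. Strings stand in for the output type T.
final class MappingSketch {
    private MappingSketch() {}

    static Map<String, Integer> indexMapping(List<String> namesInGraphOrder) {
        Map<String, Integer> mapping = new LinkedHashMap<>();
        for (String name : namesInGraphOrder) {
            // mapping.size() is the next free id; putIfAbsent returns non-null on a duplicate.
            if (mapping.putIfAbsent(name, mapping.size()) != null) {
                throw new IllegalArgumentException("Duplicate name: " + name);
            }
        }
        return mapping;
    }

    public static void main(String[] args) {
        Map<String, Integer> featureMapping = indexMapping(Arrays.asList("pixel-0", "pixel-1", "pixel-2"));
        Map<String, Integer> outputMapping = indexMapping(Arrays.asList("cat", "dog"));
        System.out.println(featureMapping + " " + outputMapping);
        // prints {pixel-0=0, pixel-1=1, pixel-2=2} {cat=0, dog=1}
    }
}
```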