Class TensorFlowUtil

java.lang.Object
org.tribuo.interop.tensorflow.TensorFlowUtil

public abstract class TensorFlowUtil extends Object
Helper functions for working with TensorFlow.
  • Method Details

    • closeTensorCollection

      public static void closeTensorCollection(Collection<org.tensorflow.Tensor> tensors)
      Closes a collection of Tensors.
      Parameters:
      tensors - The collection of tensors to close.
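
As an illustration of the pattern, the sketch below closes every element of a collection of closeable resources. It does not call Tribuo or TensorFlow: `StubTensor` and `closeAll` are hypothetical stand-ins introduced here, and the actual implementation of `closeTensorCollection` may differ. A real `org.tensorflow.Tensor` is closeable and backed by native memory, which is why releasing a whole batch of them in one call is convenient.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

final class CloseCollectionSketch {
    /** Hypothetical stand-in for a tensor: a resource that records whether it was closed. */
    static final class StubTensor implements AutoCloseable {
        private boolean closed = false;

        @Override
        public void close() {
            closed = true;
        }

        boolean isClosed() {
            return closed;
        }
    }

    /** Closes every resource in the collection (illustrative helper, not the Tribuo method). */
    static void closeAll(Collection<StubTensor> tensors) {
        for (StubTensor t : tensors) {
            t.close();
        }
    }

    public static void main(String[] args) {
        List<StubTensor> tensors = new ArrayList<>();
        tensors.add(new StubTensor());
        tensors.add(new StubTensor());
        closeAll(tensors);
        for (StubTensor t : tensors) {
            System.out.println("closed = " + t.isClosed());
        }
    }
}
```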
    • annotateGraph

      public static void annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session)
      Annotates a graph with an extra placeholder and assign operation for each VariableV2. This allows the graph to be deserialised using restoreMarshalledVariables(Session, Map).

      This operation can be done either each time the Graph is loaded, before restoreMarshalledVariables is called, or once, with the updated GraphDef persisted alongside the Map produced by extractMarshalledVariables.

      Requires a session to correctly get the output type of a VariableV2.

      Parameters:
      graph - The graph to annotate.
      session - The session to use.
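
The shape of the annotation can be pictured with a toy model. The sketch below does not use the TensorFlow API: `Op`, `annotate`, and the `-placeholder` and `-assign` suffixes are hypothetical names chosen for illustration, and the operation names and wiring that `annotateGraph` actually produces may differ. It shows only the transformation described above: each `VariableV2` operation gains one placeholder and one assign operation that feeds the placeholder's value into the variable.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

final class AnnotateGraphSketch {
    /** Hypothetical stand-in for a graph operation: a name, an op type, and input names. */
    static final class Op {
        final String name;
        final String type;
        final List<String> inputs;

        Op(String name, String type, List<String> inputs) {
            this.name = name;
            this.type = type;
            this.inputs = inputs;
        }

        @Override
        public String toString() {
            return type + "(" + name + ") <- " + inputs;
        }
    }

    /** Returns a copy of the op list with a placeholder and an assign op added per VariableV2. */
    static List<Op> annotate(List<Op> graph) {
        List<Op> annotated = new ArrayList<>(graph);
        for (Op op : graph) {
            if (op.type.equals("VariableV2")) {
                String placeholder = op.name + "-placeholder"; // assumed naming scheme
                annotated.add(new Op(placeholder, "Placeholder", Collections.emptyList()));
                // The assign op takes the variable and the placeholder as its inputs.
                annotated.add(new Op(op.name + "-assign", "Assign",
                        Arrays.asList(op.name, placeholder)));
            }
        }
        return annotated;
    }

    public static void main(String[] args) {
        List<Op> graph = Arrays.asList(
                new Op("weights", "VariableV2", Collections.emptyList()),
                new Op("input", "Placeholder", Collections.emptyList()),
                new Op("output", "MatMul", Arrays.asList("input", "weights")));
        for (Op op : annotate(graph)) {
            System.out.println(op);
        }
    }
}
```

Because the added operations depend only on the variables already in the graph, the annotated graph can be stored once and reused, which matches the second option described above.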
    • generatePlaceholderName

      public static String generatePlaceholderName(String variableName)
      Creates a name for a placeholder based on the supplied variable name.
      Parameters:
      variableName - The variable name to use as a base.
      Returns:
      A name for the placeholder.
    • extractMarshalledVariables

      public static Map<String,TensorFlowUtil.TensorTuple> extractMarshalledVariables(org.tensorflow.Graph graph, org.tensorflow.Session session)
      Extracts a Map containing the name of each TensorFlow VariableV2 and the associated parameter array. This map can then be serialised to disk.
      Parameters:
      graph - The graph to read operations from.
      session - The session to read from.
      Returns:
      A map containing all variable names and parameter arrays.
    • restoreMarshalledVariables

      public static void restoreMarshalledVariables(org.tensorflow.Session session, Map<String,TensorFlowUtil.TensorTuple> tensorMap)
      Writes a map containing the name of each TensorFlow VariableV2 and the associated parameter array into the supplied session.
      Parameters:
      session - The session to write to.
      tensorMap - The parameter map to write.
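
Together, extractMarshalledVariables and restoreMarshalledVariables form a round trip: variable values are copied out of one session into a map keyed by variable name, then written into another session. The sketch below models that round trip with plain maps of float arrays. It is not the Tribuo implementation: `ToySession`, `extract`, and `restore` are hypothetical stand-ins, and `float[]` takes the place of `TensorFlowUtil.TensorTuple`.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

final class MarshalRoundTripSketch {
    /** Hypothetical stand-in for a session: variable names mapped to their current values. */
    static final class ToySession {
        final Map<String, float[]> variables = new HashMap<>();
    }

    /** Copies every variable out of the session into a standalone map. */
    static Map<String, float[]> extract(ToySession session) {
        Map<String, float[]> snapshot = new HashMap<>();
        for (Map.Entry<String, float[]> e : session.variables.entrySet()) {
            snapshot.put(e.getKey(), e.getValue().clone()); // defensive copy
        }
        return snapshot;
    }

    /** Writes every entry of the map into the session, overwriting existing values. */
    static void restore(ToySession session, Map<String, float[]> snapshot) {
        for (Map.Entry<String, float[]> e : snapshot.entrySet()) {
            session.variables.put(e.getKey(), e.getValue().clone());
        }
    }

    public static void main(String[] args) {
        ToySession trained = new ToySession();
        trained.variables.put("weights", new float[] {0.5f, -1.0f});
        trained.variables.put("bias", new float[] {0.1f});

        // Snapshot the trained values, then load them into a fresh session.
        Map<String, float[]> snapshot = extract(trained);
        ToySession fresh = new ToySession();
        restore(fresh, snapshot);

        System.out.println(Arrays.toString(fresh.variables.get("weights")));
    }
}
```

With the real API, the graph must first be annotated with annotateGraph so that each variable has the placeholder and assign operation that restoreMarshalledVariables relies on.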