Class TensorflowUtil

java.lang.Object
org.tribuo.interop.tensorflow.TensorflowUtil

public class TensorflowUtil extends Object
Helper functions for working with Tensorflow.
  • Field Summary

    Fields
    Modifier and Type
    Field
    Description
    static final String
     
    static final String
     
    static final String
     
    static final String
     
    static final String
     
  • Constructor Summary

    Constructors
    Constructor
    Description
     
  • Method Summary

    Modifier and Type
    Method
    Description
    static void
    annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session)
    Annotates a graph with an extra placeholder and assign operation for each VariableV2.
    static void
    closeTensorList(List<org.tensorflow.Tensor<?>> tensorList)
    Closes a list of Tensors.
    static Object
    convertTensorToArray(org.tensorflow.Tensor<?> tensor)
    Extracts the appropriate type of primitive array from a Tensor.
    static Object
    convertTensorToScalar(org.tensorflow.Tensor<?> tensor)
    Converts a Tensor into a scalar object, boxing the primitive types.
    static void
    deserialise(org.tensorflow.Session session, Map<String,Object> tensorMap)
    Writes a map containing the name of each Tensorflow VariableV2 and the associated parameter array into the supplied session.
    static String
    generatePlaceholderName(String variableName)
    Creates a name for a placeholder based on the supplied variable name.
    static Object
    newBooleanArray(long[] shape)
    Creates a new primitive boolean array of up to 8 dimensions, using the supplied shape.
    static Object
    newByteArray(long[] shape)
    Creates a new primitive byte array of up to 8 dimensions, using the supplied shape.
    static Object
    newDoubleArray(long[] shape)
    Creates a new primitive double array of up to 8 dimensions, using the supplied shape.
    static Object
    newFloatArray(long[] shape)
    Creates a new primitive float array of up to 8 dimensions, using the supplied shape.
    static Object
    newIntArray(long[] shape)
    Creates a new primitive int array of up to 8 dimensions, using the supplied shape.
    static Object
    newLongArray(long[] shape)
    Creates a new primitive long array of up to 8 dimensions, using the supplied shape.
    static Map<String,Object>
    serialise(org.tensorflow.Graph graph, org.tensorflow.Session session)
    Extracts a Map containing the name of each Tensorflow VariableV2 and the associated parameter array.

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
  • Field Details

  • Constructor Details

  • Method Details

    • newBooleanArray

      public static Object newBooleanArray(long[] shape)
      Creates a new primitive boolean array of up to 8 dimensions, using the supplied shape.

      Does not check whether all elements of the shape are positive.

      Parameters:
      shape - The shape of array to create.
      Returns:
      A boolean array.
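All six array factories share one pattern: turn the long[] shape into per-dimension sizes and allocate a nested primitive array. The following is a minimal sketch of that pattern, assuming java.lang.reflect.Array.newInstance is an acceptable substitute for explicit per-rank allocation. ShapedArrays and newPrimitiveArray are hypothetical names, not Tribuo's API. Like the documented methods, it does not check that sizes are positive; a negative size surfaces as a NegativeArraySizeException from Array.newInstance. It rejects an empty shape, because Array.newInstance needs at least one dimension.

```java
import java.lang.reflect.Array;

final class ShapedArrays {
    private ShapedArrays() {}

    // Hypothetical generic factory: allocates a nested array of a primitive element
    // type, e.g. boolean.class with shape {2, 3} gives a boolean[2][3].
    static Object newPrimitiveArray(Class<?> elementType, long[] shape) {
        if (!elementType.isPrimitive()) {
            throw new IllegalArgumentException("Not a primitive type: " + elementType);
        }
        if (shape.length == 0 || shape.length > 8) {
            throw new IllegalArgumentException("Expected 1 to 8 dimensions, found " + shape.length);
        }
        int[] dims = new int[shape.length];
        for (int i = 0; i < shape.length; i++) {
            // Java arrays are int-indexed; toIntExact throws if a size overflows int.
            dims[i] = Math.toIntExact(shape[i]);
        }
        // Builds every level of the nested array in one call.
        return Array.newInstance(elementType, dims);
    }

    // The boolean variant is then a one-line specialisation.
    static Object newBooleanArray(long[] shape) {
        return newPrimitiveArray(boolean.class, shape);
    }
}
```

The caller casts the result to the type it expects, e.g. (boolean[][]) ShapedArrays.newBooleanArray(new long[]{2, 3}).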
    • newByteArray

      public static Object newByteArray(long[] shape)
      Creates a new primitive byte array of up to 8 dimensions, using the supplied shape.

      Does not check whether all elements of the shape are positive.

      Parameters:
      shape - The shape of array to create.
      Returns:
      A byte array.
    • newIntArray

      public static Object newIntArray(long[] shape)
      Creates a new primitive int array of up to 8 dimensions, using the supplied shape.

      Does not check whether all elements of the shape are positive.

      Parameters:
      shape - The shape of array to create.
      Returns:
      An int array.
    • newLongArray

      public static Object newLongArray(long[] shape)
      Creates a new primitive long array of up to 8 dimensions, using the supplied shape.

      Does not check whether all elements of the shape are positive.

      Parameters:
      shape - The shape of array to create.
      Returns:
      A long array.
    • newFloatArray

      public static Object newFloatArray(long[] shape)
      Creates a new primitive float array of up to 8 dimensions, using the supplied shape.

      Does not check whether all elements of the shape are positive.

      Parameters:
      shape - The shape of array to create.
      Returns:
      A float array.
    • newDoubleArray

      public static Object newDoubleArray(long[] shape)
      Creates a new primitive double array of up to 8 dimensions, using the supplied shape.

      Does not check whether all elements of the shape are positive.

      Parameters:
      shape - The shape of array to create.
      Returns:
      A double array.
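None of these factories validate the shape. A caller that takes shapes from untrusted input may want to check them first. The following is a minimal sketch of such a check. It applies the documented limits (at most 8 dimensions, every size positive) and adds the int range of Java array indices. ShapeChecks and checkShape are hypothetical names.

```java
final class ShapeChecks {
    private ShapeChecks() {}

    // Hypothetical validator: at most 8 dimensions, and every size positive and
    // small enough to index a Java array. Returns the sizes as ints on success.
    static int[] checkShape(long[] shape) {
        if (shape.length > 8) {
            throw new IllegalArgumentException("Rank " + shape.length + " exceeds 8");
        }
        int[] dims = new int[shape.length];
        for (int i = 0; i < shape.length; i++) {
            long size = shape[i];
            if (size <= 0 || size > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Invalid size " + size + " for dimension " + i);
            }
            dims[i] = (int) size;
        }
        return dims;
    }
}
```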
    • closeTensorList

      public static void closeTensorList(List<org.tensorflow.Tensor<?>> tensorList)
      Closes a list of Tensors.
      Parameters:
      tensorList - The list of tensors to close.
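The documentation does not say what happens if one close() call fails. The following is a minimal sketch of a close-everything loop that keeps going past failures. It is written against AutoCloseable, which org.tensorflow.Tensor implements, so it runs without TensorFlow. CloseAll is a hypothetical name, and the failure policy is an assumption, not necessarily Tribuo's.

```java
import java.util.List;

final class CloseAll {
    private CloseAll() {}

    // Closes every resource, even if some close() calls throw. The first failure is
    // wrapped in an IllegalStateException; later failures are attached as suppressed.
    static void closeAll(List<? extends AutoCloseable> resources) {
        IllegalStateException failure = null;
        for (AutoCloseable resource : resources) {
            try {
                resource.close();
            } catch (Exception e) {
                if (failure == null) {
                    failure = new IllegalStateException("Failed to close a resource", e);
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }
}
```

Continuing past a failed close() avoids leaking the native memory held by the remaining tensors.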
    • convertTensorToArray

      public static Object convertTensorToArray(org.tensorflow.Tensor<?> tensor)
      Extracts the appropriate type of primitive array from a Tensor.

      Returns Object because the caller does not know in advance which primitive type the Tensor contains.

      Parameters:
      tensor - The tensor to read.
      Returns:
      A primitive array.
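Because the return type is Object, a caller that does not know the element type has to recover it, along with the shape. The following is a minimal sketch using java.lang.reflect.Array. It assumes a rectangular (non-jagged) array. ArrayInspector and its methods are hypothetical names. A caller that already knows the tensor's type can simply cast, e.g. (float[][]) result.

```java
import java.lang.reflect.Array;

final class ArrayInspector {
    private ArrayInspector() {}

    // Recovers the shape by walking down index 0 at each level.
    // Assumes every sub-array at a given level has the same length.
    static long[] shapeOf(Object array) {
        int rank = 0;
        for (Class<?> c = array.getClass(); c.isArray(); c = c.getComponentType()) {
            rank++;
        }
        long[] shape = new long[rank];
        Object current = array;
        for (int i = 0; i < rank; i++) {
            int length = Array.getLength(current);
            shape[i] = length;
            if (i < rank - 1) {
                if (length == 0) {
                    break; // deeper sizes cannot be observed; they stay 0
                }
                current = Array.get(current, 0);
            }
        }
        return shape;
    }

    // Returns the primitive element type, e.g. float.class for a float[][].
    static Class<?> elementTypeOf(Object array) {
        Class<?> c = array.getClass();
        while (c.isArray()) {
            c = c.getComponentType();
        }
        return c;
    }
}
```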
    • convertTensorToScalar

      public static Object convertTensorToScalar(org.tensorflow.Tensor<?> tensor)
      Converts a Tensor into a scalar object, boxing the primitive types.

      Does not close the Tensor.

      Parameters:
      tensor - The tensor to convert.
      Returns:
      A boxed scalar.
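Boxing follows from returning Object: each primitive read is autoboxed to its wrapper class. The following is a minimal sketch that decodes a scalar from a raw buffer and boxes it according to a type tag. The DType enum and the ScalarBoxing class are hypothetical. Reading from a ByteBuffer is an assumption made so the sketch runs without TensorFlow; the documented method takes a Tensor instead.

```java
import java.nio.ByteBuffer;

final class ScalarBoxing {
    private ScalarBoxing() {}

    // Hypothetical type tags, loosely mirroring common TensorFlow element types.
    enum DType { FLOAT, DOUBLE, INT32, INT64, UINT8, BOOL }

    // Reads one value at index 0 with the accessor matching the tag. Because the
    // return type is Object, each primitive is autoboxed (Float, Double, Integer, ...).
    // The buffer's byte order must match how the value was written.
    static Object toBoxedScalar(DType type, ByteBuffer buffer) {
        switch (type) {
            case FLOAT:  return buffer.getFloat(0);
            case DOUBLE: return buffer.getDouble(0);
            case INT32:  return buffer.getInt(0);
            case INT64:  return buffer.getLong(0);
            case UINT8:  return buffer.get(0);      // boxed as Byte (signed in Java)
            case BOOL:   return buffer.get(0) != 0; // boxed as Boolean
            default:     throw new IllegalArgumentException("Unsupported type " + type);
        }
    }
}
```

Callers then use instanceof or a cast, e.g. ((Number) value).doubleValue() for the numeric types.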
    • annotateGraph

      public static void annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session)
      Annotates a graph with an extra placeholder and assign operation for each VariableV2. This allows the graph to be deserialised using deserialise(Session, Map).

      This can be done either each time the Graph is loaded, before deserialise is called, or once, with the updated GraphDef persisted alongside the Map produced by serialise.

      Requires a session to determine the output type of each VariableV2. This isn't strictly necessary, but the TensorFlow version this class targets provides no typed way to get an operation's outputs.

      Parameters:
      graph - The graph to annotate.
      session - The session to use.
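The first step of such an annotation is finding the VariableV2 operations. The following is a minimal sketch over a plain map from operation name to op type, standing in for iterating the Graph's operations. It also skips variables whose placeholder already exists, which matters if an annotated GraphDef is persisted and then annotated again. That guard, the VariableScan class and the naming-function parameter are assumptions; the documentation does not say whether annotateGraph itself checks for existing placeholders.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

final class VariableScan {
    private VariableScan() {}

    // TensorFlow's op type name for a (TF 1.x style) mutable variable.
    static final String VARIABLE_V2 = "VariableV2";

    // Returns the VariableV2 ops that still need a placeholder and assign op.
    // opTypesByName maps each operation name to its op type; placeholderName derives
    // the placeholder's name from a variable name (a hypothetical naming scheme).
    static List<String> variablesToAnnotate(Map<String, String> opTypesByName,
                                            Function<String, String> placeholderName) {
        List<String> pending = new ArrayList<>();
        for (Map.Entry<String, String> op : opTypesByName.entrySet()) {
            boolean isVariable = VARIABLE_V2.equals(op.getValue());
            if (isVariable && !opTypesByName.containsKey(placeholderName.apply(op.getKey()))) {
                pending.add(op.getKey());
            }
        }
        return pending;
    }
}
```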
    • generatePlaceholderName

      public static String generatePlaceholderName(String variableName)
      Creates a name for a placeholder based on the supplied variable name.
      Parameters:
      variableName - The variable name to use as a base.
      Returns:
      A name for the placeholder.
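The key property is determinism: the placeholder name must be recomputable from the variable name alone. Presumably this is what lets the placeholders added by annotateGraph be found again by deserialise without any shared state. The following is a minimal sketch with a hypothetical suffix; the string Tribuo actually uses is not given in this documentation. PlaceholderNames and variableNameOf are likewise hypothetical.

```java
final class PlaceholderNames {
    private PlaceholderNames() {}

    // Hypothetical suffix; an underscore keeps the result a valid TensorFlow op name.
    static final String SUFFIX = "_placeholder";

    // Deterministic: the same variable name always yields the same placeholder name.
    static String generatePlaceholderName(String variableName) {
        return variableName + SUFFIX;
    }

    // Inverse mapping, for when only the placeholder name is at hand.
    static String variableNameOf(String placeholderName) {
        if (!placeholderName.endsWith(SUFFIX)) {
            throw new IllegalArgumentException("Not a placeholder name: " + placeholderName);
        }
        return placeholderName.substring(0, placeholderName.length() - SUFFIX.length());
    }
}
```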
    • serialise

      public static Map<String,Object> serialise(org.tensorflow.Graph graph, org.tensorflow.Session session)
      Extracts a Map containing the name of each Tensorflow VariableV2 and the associated parameter array. This map can then be serialised to disk.
      Parameters:
      graph - The graph to read operations from.
      session - The session to read from.
      Returns:
      A map containing all variable names and parameter arrays.
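The documentation leaves the on-disk format open. One option is standard Java serialization, which works because String keys and primitive arrays, nested or not, are Serializable. The following is a minimal sketch that round-trips the map through a byte array; writing it to a file is then Files.write(path, bytes). ParameterMapIO is a hypothetical name, and Tribuo's own persistence mechanism may differ.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;

final class ParameterMapIO {
    private ParameterMapIO() {}

    // Serialises a copy of the map. HashMap, String and primitive arrays are all
    // Serializable, so a map of primitive parameter arrays qualifies.
    static byte[] toBytes(Map<String, Object> params) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(new HashMap<>(params));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Reads the map back; each value comes back as the same primitive array type.
    @SuppressWarnings("unchecked")
    static Map<String, Object> fromBytes(byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return (Map<String, Object>) in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```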
    • deserialise

      public static void deserialise(org.tensorflow.Session session, Map<String,Object> tensorMap)
      Writes a map containing the name of each Tensorflow VariableV2 and the associated parameter array into the supplied session.
      Parameters:
      session - The session to write to.
      tensorMap - The parameter map to write.
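Given an annotated graph, writing the parameters back amounts to feeding each array into its variable's placeholder and running the matching assign operation. The following is a minimal sketch against a hypothetical Runner interface shaped like TensorFlow 1.x's Session.Runner (feed, addTarget, run). The interface, the ParameterLoader class and both naming functions are assumptions. With the real API each array would first be wrapped in a Tensor, which must be closed after run().

```java
import java.util.Map;
import java.util.function.Function;

final class ParameterLoader {
    private ParameterLoader() {}

    // Hypothetical stand-in for a session runner: values are fed into named
    // placeholders, and named target ops are executed when run() is called.
    interface Runner {
        Runner feed(String placeholderName, Object value);
        Runner addTarget(String opName);
        void run();
    }

    // For each variable, feeds its saved array into the variable's placeholder and
    // schedules the matching assign op, then executes everything in a single run().
    static void load(Runner runner, Map<String, Object> tensorMap,
                     Function<String, String> placeholderName,
                     Function<String, String> assignName) {
        for (Map.Entry<String, Object> entry : tensorMap.entrySet()) {
            runner.feed(placeholderName.apply(entry.getKey()), entry.getValue());
            runner.addTarget(assignName.apply(entry.getKey()));
        }
        runner.run();
    }
}
```

Batching every assignment into one run() call avoids a separate session round trip per variable.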