Class TensorflowUtil
-
Field Summary
Fields: VARIABLE_V2, ASSIGN_OP, ASSIGN_PLACEHOLDER, PLACEHOLDER, DTYPE
Constructor Summary
Constructors: TensorflowUtil()
Method Summary
Modifier and Type, Method, Description:
- static void annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session)
  Annotates a graph with an extra placeholder and assign operation for each VariableV2.
- static void closeTensorList(List<org.tensorflow.Tensor<?>> tensorList)
  Closes a list of Tensors.
- static Object convertTensorToArray(org.tensorflow.Tensor<?> tensor)
  Extracts the appropriate type of primitive array from a Tensor.
- static Object convertTensorToScalar(org.tensorflow.Tensor<?> tensor)
  Converts a Tensor into a scalar object, boxing the primitive types.
- static void deserialise(org.tensorflow.Session session, Map<String, Object> tensorMap)
  Writes a map containing the name of each TensorFlow VariableV2 and the associated parameter array into the supplied session.
- static String generatePlaceholderName(String variableName)
  Creates a name for a placeholder based on the supplied variable name.
- static Object newBooleanArray(long[] shape)
  Creates a new primitive boolean array of up to 8 dimensions, using the supplied shape.
- static Object newByteArray(long[] shape)
  Creates a new primitive byte array of up to 8 dimensions, using the supplied shape.
- static Object newDoubleArray(long[] shape)
  Creates a new primitive double array of up to 8 dimensions, using the supplied shape.
- static Object newFloatArray(long[] shape)
  Creates a new primitive float array of up to 8 dimensions, using the supplied shape.
- static Object newIntArray(long[] shape)
  Creates a new primitive int array of up to 8 dimensions, using the supplied shape.
- static Object newLongArray(long[] shape)
  Creates a new primitive long array of up to 8 dimensions, using the supplied shape.
- static Map<String,Object> serialise(org.tensorflow.Graph graph, org.tensorflow.Session session)
  Extracts a Map containing the name of each TensorFlow VariableV2 and the associated parameter array.
-
Field Details
-
VARIABLE_V2
-
ASSIGN_OP
-
ASSIGN_PLACEHOLDER
-
PLACEHOLDER
-
DTYPE
-
Constructor Details
-
TensorflowUtil
public TensorflowUtil()
-
Method Details
-
newBooleanArray
Creates a new primitive boolean array of up to 8 dimensions, using the supplied shape. Does not check whether all of the shape's elements are positive.
- Parameters:
shape - The shape of the array to create.
- Returns:
A boolean array.
-
newByteArray
Creates a new primitive byte array of up to 8 dimensions, using the supplied shape. Does not check whether all of the shape's elements are positive.
- Parameters:
shape - The shape of the array to create.
- Returns:
A byte array.
-
newIntArray
Creates a new primitive int array of up to 8 dimensions, using the supplied shape. Does not check whether all of the shape's elements are positive.
- Parameters:
shape - The shape of the array to create.
- Returns:
An int array.
-
newLongArray
Creates a new primitive long array of up to 8 dimensions, using the supplied shape. Does not check whether all of the shape's elements are positive.
- Parameters:
shape - The shape of the array to create.
- Returns:
A long array.
-
newFloatArray
Creates a new primitive float array of up to 8 dimensions, using the supplied shape. Does not check whether all of the shape's elements are positive.
- Parameters:
shape - The shape of the array to create.
- Returns:
A float array.
-
newDoubleArray
Creates a new primitive double array of up to 8 dimensions, using the supplied shape. Does not check whether all of the shape's elements are positive.
- Parameters:
shape - The shape of the array to create.
- Returns:
A double array.
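The six newXArray methods differ only in their element type. As a minimal sketch, assuming the results are ordinary rectangular Java primitive arrays (so a shape of {2, 3} yields a float[2][3]), the allocation can be written once with java.lang.reflect.Array.newInstance. The class PrimitiveArrays and its newArray method are hypothetical and not part of TensorflowUtil. Unlike the documented methods, this sketch explicitly rejects shapes with 0 or more than 8 dimensions; like them, it does no explicit positivity check.

```java
import java.lang.reflect.Array;

/** Hypothetical helper, not part of TensorflowUtil. */
public final class PrimitiveArrays {
    private static final int MAX_DIMS = 8;

    private PrimitiveArrays() {}

    /**
     * Allocates a rectangular primitive array with the given element type and shape,
     * so newArray(float.class, new long[]{2, 3}) returns a float[2][3].
     */
    public static Object newArray(Class<?> elementType, long[] shape) {
        if (!elementType.isPrimitive() || elementType == void.class) {
            throw new IllegalArgumentException("Not a primitive element type: " + elementType);
        }
        if (shape.length == 0 || shape.length > MAX_DIMS) {
            throw new IllegalArgumentException(
                "Expected 1 to " + MAX_DIMS + " dimensions, found " + shape.length);
        }
        int[] dims = new int[shape.length];
        for (int i = 0; i < shape.length; i++) {
            // Java array lengths are ints; toIntExact throws ArithmeticException on overflow.
            dims[i] = Math.toIntExact(shape[i]);
        }
        // No explicit positivity check: a negative size makes newInstance
        // throw NegativeArraySizeException.
        return Array.newInstance(elementType, dims);
    }
}
```

Returning Object mirrors the documented methods: the rank is only known at runtime, so the caller casts, for example float[][] w = (float[][]) PrimitiveArrays.newArray(float.class, new long[]{2, 3}).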
-
closeTensorList
Closes a list of Tensors.
- Parameters:
tensorList - The list of tensors to close.
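In the TensorFlow Java 1.x API, Tensor implements AutoCloseable and holds native memory, so every tensor in the list must be closed. A minimal sketch of closing a whole list, written generically over AutoCloseable so it runs without TensorFlow, follows. The class CloseAll is hypothetical, and the choice to keep closing after a failure (rethrowing the first exception with later ones attached as suppressed) is an assumption; the documentation does not say how closeTensorList handles errors.

```java
import java.util.List;

/** Hypothetical helper, not part of TensorflowUtil. */
public final class CloseAll {
    private CloseAll() {}

    /**
     * Closes every resource in the list, even if an earlier close throws.
     * The first failure is rethrown with any later failures attached as suppressed exceptions.
     */
    public static void closeAll(List<? extends AutoCloseable> resources) throws Exception {
        Exception first = null;
        for (AutoCloseable resource : resources) {
            try {
                resource.close();
            } catch (Exception e) {
                if (first == null) {
                    first = e;
                } else {
                    first.addSuppressed(e);
                }
            }
        }
        if (first != null) {
            throw first;
        }
    }
}
```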
-
convertTensorToArray
Extracts the appropriate type of primitive array from a Tensor. Returns an Object because the caller does not know in advance what type is stored in the Tensor.
- Parameters:
tensor - The tensor to read.
- Returns:
A primitive array.
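Because convertTensorToArray returns a plain Object, the caller has to discover the runtime type, either with instanceof checks such as value instanceof float[][] or by reflection. The hypothetical ArrayInspector below, which is not part of TensorflowUtil, is a small sketch of the reflective route: it walks Class.getComponentType to recover the rank and innermost element type.

```java
/** Hypothetical helper, not part of TensorflowUtil. */
public final class ArrayInspector {
    private ArrayInspector() {}

    /** Returns the number of array dimensions of the value, or 0 if it is not an array. */
    public static int rank(Object value) {
        int rank = 0;
        Class<?> c = value.getClass();
        while (c.isArray()) {
            rank++;
            c = c.getComponentType();
        }
        return rank;
    }

    /** Returns the innermost element type, e.g. float.class for a float[][]. */
    public static Class<?> elementType(Object value) {
        Class<?> c = value.getClass();
        while (c.isArray()) {
            c = c.getComponentType();
        }
        return c;
    }
}
```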
-
convertTensorToScalar
Converts a Tensor into a scalar object, boxing the primitive types. Does not close the Tensor.
- Parameters:
tensor - The tensor to convert.
- Returns:
A boxed scalar.
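Since the result is boxed and the Tensor stays open, the caller both unwraps the value and remains responsible for closing the tensor (for example with try-with-resources, as Tensor is AutoCloseable in TensorFlow Java 1.x). The sketch below shows only the unwrapping side, assuming the boxed result is some java.lang.Number or a Boolean; the exact set of boxed types depends on the tensor's data type and is an assumption here. Scalars.toDouble is hypothetical, and mapping true/false to 1.0/0.0 is an illustrative choice.

```java
/** Hypothetical helper, not part of TensorflowUtil. */
public final class Scalars {
    private Scalars() {}

    /** Converts a boxed scalar (any Number, or a Boolean as 1.0/0.0) to a double. */
    public static double toDouble(Object boxed) {
        if (boxed instanceof Number) {
            return ((Number) boxed).doubleValue();
        } else if (boxed instanceof Boolean) {
            return ((Boolean) boxed) ? 1.0 : 0.0;
        } else {
            throw new IllegalArgumentException("Unsupported scalar type: "
                + (boxed == null ? "null" : boxed.getClass().getName()));
        }
    }
}
```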
-
annotateGraph
Annotates a graph with an extra placeholder and assign operation for each VariableV2. This allows the graph to be deserialised using deserialise(Session, Map).
This operation can be done either each time the Graph is loaded, before deserialise is called, or once, with the updated GraphDef persisted alongside the Map produced by serialise.
Requires a session to correctly get the output type of a VariableV2. This isn't strictly necessary, but there are no typed ways to get outputs in the TF version we use.
- Parameters:
graph - The graph to annotate.
session - The session to use.
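The annotation step amounts to scanning the graph for VariableV2 operations and planning two new operations for each one. As a minimal sketch that runs without TensorFlow, the graph is modelled below as a list of (name, type) pairs; the real method would instead inspect the graph's operations and build the new ones against the Graph, setting the placeholder's dtype from the variable's output type. AnnotationPlan, its Op class, and the "-placeholder"/"-assign" suffixes are all hypothetical; the actual placeholder name comes from generatePlaceholderName.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical model of the annotation plan, not part of TensorflowUtil. */
public final class AnnotationPlan {
    private AnnotationPlan() {}

    /** Minimal stand-in for a graph operation: its name and op type. */
    public static final class Op {
        public final String name;
        public final String type;

        public Op(String name, String type) {
            this.name = name;
            this.type = type;
        }

        @Override
        public String toString() {
            return type + ":" + name;
        }
    }

    /** Plans one Placeholder and one Assign op for every VariableV2 op in the graph. */
    public static List<Op> extraOps(List<Op> graph) {
        List<Op> extra = new ArrayList<>();
        for (Op op : graph) {
            if ("VariableV2".equals(op.type)) {
                // Hypothetical names; the real placeholder name comes from generatePlaceholderName.
                extra.add(new Op(op.name + "-placeholder", "Placeholder"));
                // The Assign op would take the variable and its placeholder as inputs.
                extra.add(new Op(op.name + "-assign", "Assign"));
            }
        }
        return extra;
    }
}
```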
-
generatePlaceholderName
Creates a name for a placeholder based on the supplied variable name.
- Parameters:
variableName - The variable name to use as a base.
- Returns:
A name for the placeholder.
-
serialise
public static Map<String,Object> serialise(org.tensorflow.Graph graph, org.tensorflow.Session session)
Extracts a Map containing the name of each TensorFlow VariableV2 and the associated parameter array. This map can then be serialised to disk.
- Parameters:
graph - The graph to read operations from.
session - The session to read from.
- Returns:
A map containing all variable names and parameter arrays.
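The documentation does not say which format is used to write the map to disk. One option, sketched below under the assumption that every value is a primitive array (which is Serializable in Java), is plain Java serialization of a HashMap copy. ParameterMapIO is hypothetical; it works on byte arrays so it is easy to test, and writing to a file only requires swapping the byte streams for Files.newOutputStream and Files.newInputStream.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical helper, not part of TensorflowUtil. */
public final class ParameterMapIO {
    private ParameterMapIO() {}

    /** Serialises a copy of the map; HashMap, String and primitive arrays are all Serializable. */
    public static byte[] toBytes(Map<String, Object> params) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(new HashMap<>(params));
        }
        return bytes.toByteArray();
    }

    /** Reads back a map written by toBytes. */
    @SuppressWarnings("unchecked")
    public static Map<String, Object> fromBytes(byte[] data) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return (Map<String, Object>) in.readObject();
        }
    }
}
```

Java serialization is convenient but ties the file to Java; a format such as a protobuf or a simple length-prefixed binary layout would be more portable, at the cost of writing the array encoding by hand.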
-
deserialise
Writes a map containing the name of each TensorFlow VariableV2 and the associated parameter array into the supplied session.
- Parameters:
session - The session to write to.
tensorMap - The parameter map to write.
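Given the placeholder and assign operation that annotateGraph adds for each VariableV2, deserialising plausibly amounts to feeding each parameter array into its placeholder and running the matching assign operation. Under that assumption, the sketch below only builds the plan: which placeholder receives which array, and which assign operations to run. In TensorFlow Java 1.x the plan could then be executed with a Session.Runner, wrapping each array with Tensor.create and passing it to feed(String, Tensor), adding each assign name with addTarget(String), and calling run(); that step is omitted so the example runs without TensorFlow. FeedPlan and the two naming functions are hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.UnaryOperator;

/** Hypothetical model of a deserialisation plan, not part of TensorflowUtil. */
public final class FeedPlan {
    /** Placeholder name mapped to the parameter array to feed into it. */
    public final Map<String, Object> feeds = new LinkedHashMap<>();
    /** Names of the assign ops to run, one per variable. */
    public final List<String> targets = new ArrayList<>();

    private FeedPlan() {}

    /** Builds feeds and targets for every variable in the parameter map. */
    public static FeedPlan build(Map<String, Object> tensorMap,
                                 UnaryOperator<String> placeholderName,
                                 UnaryOperator<String> assignName) {
        FeedPlan plan = new FeedPlan();
        for (Map.Entry<String, Object> entry : tensorMap.entrySet()) {
            String variable = entry.getKey();
            plan.feeds.put(placeholderName.apply(variable), entry.getValue());
            plan.targets.add(assignName.apply(variable));
        }
        return plan;
    }
}
```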
-