public class TensorflowUtil extends Object
Modifier and Type | Field and Description |
---|---|
static String | ASSIGN_OP |
static String | ASSIGN_PLACEHOLDER |
static String | DTYPE |
static String | PLACEHOLDER |
static String | VARIABLE_V2 |
Constructor and Description |
---|
TensorflowUtil() |
Modifier and Type | Method and Description |
---|---|
static void | annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session) - Annotates a graph with an extra placeholder and assign operation for each VariableV2. |
static void | closeTensorList(List<org.tensorflow.Tensor<?>> tensorList) - Closes a list of Tensors. |
static Object | convertTensorToArray(org.tensorflow.Tensor<?> tensor) - Extracts the appropriate type of primitive array from a Tensor. |
static Object | convertTensorToScalar(org.tensorflow.Tensor<?> tensor) - Converts a Tensor into a scalar object, boxing the primitive types. |
static void | deserialise(org.tensorflow.Session session, Map<String,Object> tensorMap) - Writes a map containing the name of each Tensorflow VariableV2 and the associated parameter array into the supplied session. |
static String | generatePlaceholderName(String variableName) - Creates a name for a placeholder based on the supplied variable name. |
static Object | newBooleanArray(long[] shape) - Creates a new primitive boolean array of up to 8 dimensions, using the supplied shape. |
static Object | newByteArray(long[] shape) - Creates a new primitive byte array of up to 8 dimensions, using the supplied shape. |
static Object | newDoubleArray(long[] shape) - Creates a new primitive double array of up to 8 dimensions, using the supplied shape. |
static Object | newFloatArray(long[] shape) - Creates a new primitive float array of up to 8 dimensions, using the supplied shape. |
static Object | newIntArray(long[] shape) - Creates a new primitive int array of up to 8 dimensions, using the supplied shape. |
static Object | newLongArray(long[] shape) - Creates a new primitive long array of up to 8 dimensions, using the supplied shape. |
static Map<String,Object> | serialise(org.tensorflow.Graph graph, org.tensorflow.Session session) - Extracts a Map containing the name of each Tensorflow VariableV2 and the associated parameter array. |
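closeTensorList exists because each org.tensorflow.Tensor holds native memory. That memory must be explicitly freed by calling close(), and Tensor implements AutoCloseable. The sketch below shows the same idea over any list of AutoCloseable resources.

The names CloseAllSketch and closeAll are hypothetical. Two behaviours are design choices of this sketch and are not documented behaviour of TensorflowUtil:

- It keeps closing the remaining elements after one fails.
- It attaches later failures to the first exception as suppressed exceptions.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: closes every resource in a list, as closeTensorList does for Tensors. */
final class CloseAllSketch {

    /**
     * Closes each element in order, continuing past failures so one bad close
     * does not leak the remaining resources. The first exception is rethrown
     * with any later ones attached via addSuppressed.
     */
    static void closeAll(List<? extends AutoCloseable> resources) throws Exception {
        Exception first = null;
        for (AutoCloseable resource : resources) {
            try {
                resource.close();
            } catch (Exception e) {
                if (first == null) {
                    first = e;
                } else {
                    first.addSuppressed(e);
                }
            }
        }
        if (first != null) {
            throw first;
        }
    }

    public static void main(String[] args) throws Exception {
        List<String> closed = new ArrayList<>();
        List<AutoCloseable> resources = new ArrayList<>();
        for (int i = 0; i < 3; i++) {
            final int id = i;
            // Each fake resource records that it was closed.
            resources.add(() -> closed.add("resource-" + id));
        }
        closeAll(resources);
        System.out.println(closed); // prints "[resource-0, resource-1, resource-2]"
    }
}
```

For a single Tensor, try-with-resources is the usual idiom. A list helper is more convenient when a Session run returns several tensors at once, since Session.Runner.run() returns a List<Tensor<?>>.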
public static final String VARIABLE_V2
public static final String ASSIGN_OP
public static final String ASSIGN_PLACEHOLDER
public static final String PLACEHOLDER
public static final String DTYPE
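The new*Array methods return a nested primitive array typed as Object, because the number of dimensions is only known at run time. The sketch below builds such an array with java.lang.reflect.Array.newInstance. The class ShapedArraySketch and its methods are hypothetical, and this is not necessarily how TensorflowUtil implements them.

The sketch mirrors the documented methods in two ways:

- It enforces the documented limit of 8 dimensions.
- It does not check that the sizes are positive. For a negative size, Array.newInstance throws NegativeArraySizeException.

```java
import java.lang.reflect.Array;
import java.util.Arrays;

/** Hypothetical sketch: builds nested primitive arrays from a TensorFlow-style long[] shape. */
final class ShapedArraySketch {
    // Mirrors the "up to 8 dimensions" limit documented for the new*Array methods.
    static final int MAX_DIMS = 8;

    /** Creates a primitive array, e.g. double.class with shape {2, 3} yields a double[2][3]. */
    static Object newPrimitiveArray(Class<?> componentType, long[] shape) {
        if (shape.length == 0 || shape.length > MAX_DIMS) {
            throw new IllegalArgumentException(
                "Expected 1 to " + MAX_DIMS + " dimensions, found " + shape.length);
        }
        int[] dims = new int[shape.length];
        for (int i = 0; i < shape.length; i++) {
            // Java array sizes are ints; toIntExact throws ArithmeticException on overflow.
            dims[i] = Math.toIntExact(shape[i]);
        }
        // No positivity check, matching the documented methods.
        return Array.newInstance(componentType, dims);
    }

    static Object newDoubleArray(long[] shape) {
        return newPrimitiveArray(double.class, shape);
    }

    static Object newBooleanArray(long[] shape) {
        return newPrimitiveArray(boolean.class, shape);
    }

    public static void main(String[] args) {
        double[][] matrix = (double[][]) newDoubleArray(new long[]{2, 3});
        System.out.println(matrix.length + "x" + matrix[0].length); // prints "2x3"
        boolean[][][] cube = (boolean[][][]) newBooleanArray(new long[]{1, 2, 2});
        System.out.println(Arrays.deepToString(cube)); // prints "[[[false, false], [false, false]]]"
    }
}
```

The caller casts the returned Object to the concrete array type (for example double[][]) once it knows the rank of the shape.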
public static Object newBooleanArray(long[] shape)
Creates a new primitive boolean array of up to 8 dimensions, using the supplied shape.
Does not check that all elements of the shape are positive.
shape - The shape of array to create.

public static Object newByteArray(long[] shape)
Creates a new primitive byte array of up to 8 dimensions, using the supplied shape.
Does not check that all elements of the shape are positive.
shape - The shape of array to create.

public static Object newIntArray(long[] shape)
Creates a new primitive int array of up to 8 dimensions, using the supplied shape.
Does not check that all elements of the shape are positive.
shape - The shape of array to create.

public static Object newLongArray(long[] shape)
Creates a new primitive long array of up to 8 dimensions, using the supplied shape.
Does not check that all elements of the shape are positive.
shape - The shape of array to create.

public static Object newFloatArray(long[] shape)
Creates a new primitive float array of up to 8 dimensions, using the supplied shape.
Does not check that all elements of the shape are positive.
shape - The shape of array to create.

public static Object newDoubleArray(long[] shape)
Creates a new primitive double array of up to 8 dimensions, using the supplied shape.
Does not check that all elements of the shape are positive.
shape - The shape of array to create.

public static void closeTensorList(List<org.tensorflow.Tensor<?>> tensorList)
Closes a list of Tensors.
tensorList - The list of tensors to close.

public static Object convertTensorToArray(org.tensorflow.Tensor<?> tensor)
Extracts the appropriate type of primitive array from a Tensor.
Returns an Object, as the caller does not know in advance what type the Tensor contains.
tensor - The tensor to read.

public static Object convertTensorToScalar(org.tensorflow.Tensor<?> tensor)
Converts a Tensor into a scalar object, boxing the primitive types.
Does not close the Tensor.
tensor - The tensor to convert.

public static void annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session)
Annotates a graph with an extra placeholder and assign operation for each VariableV2, which is required before parameters can be written with deserialise(Session, Map).
This annotation can be done either each time the Graph is loaded, before deserialise is called, or once, with the updated GraphDef persisted alongside the Map produced by serialise.
A session is required to correctly determine the output type of each VariableV2. This is not strictly necessary, but the TensorFlow version in use has no typed way to get operation outputs.
graph - The graph to annotate.
session - The session to use.

public static String generatePlaceholderName(String variableName)
Creates a name for a placeholder based on the supplied variable name.
variableName - The variable name to use as a base.

public static Map<String,Object> serialise(org.tensorflow.Graph graph, org.tensorflow.Session session)
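Extracts a Map containing the name of each Tensorflow VariableV2 and the associated parameter array.
graph - The graph to read operations from.
session - The session to read from.

public static void deserialise(org.tensorflow.Session session, Map<String,Object> tensorMap)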
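Writes a map containing the name of each Tensorflow VariableV2 and the associated parameter array into the supplied session.
session - The session to write to.
tensorMap - The parameter map to write.

Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.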