Package org.tribuo.interop.tensorflow
Class TensorFlowUtil
java.lang.Object
org.tribuo.interop.tensorflow.TensorFlowUtil
Helper functions for working with TensorFlow.
Nested Class Summary
Modifier and Type | Class | Description
static final class | TensorFlowUtil.TensorTuple | A serializable tuple containing the tensor class name, the shape and the data.
Field Summary
Modifier and Type | Field | Description
static final String | ASSIGN_OP | The name of the assignment op.
static final String | ASSIGN_PLACEHOLDER | The name given to the assignment operation from the placeholders.
static final String | DTYPE | The name of the data type.
static final String | PLACEHOLDER | The name of the placeholder op.
static final String | VARIABLE_V2 | The name of the variable op.
Method Summary
Modifier and Type | Method | Description
static void | annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session) | Annotates a graph with an extra placeholder and assign operation for each VariableV2.
static void | closeTensorCollection(Collection<org.tensorflow.Tensor> tensors) | Closes a collection of Tensors.
static Map<String,TensorFlowUtil.TensorTuple> | extractMarshalledVariables(org.tensorflow.Graph graph, org.tensorflow.Session session) | Extracts a Map containing the name of each TensorFlow VariableV2 and the associated parameter array.
static String | generatePlaceholderName(String variableName) | Creates a name for a placeholder based on the supplied variable name.
static void | restoreMarshalledVariables(org.tensorflow.Session session, Map<String,TensorFlowUtil.TensorTuple> tensorMap) | Writes a map containing the name of each TensorFlow VariableV2 and the associated parameter array into the supplied session.
Field Details
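VARIABLE_V2
public static final String VARIABLE_V2
The name of the variable op.
See Also:
ASSIGN_OP
public static final String ASSIGN_OP
The name of the assignment op.
See Also:
ASSIGN_PLACEHOLDER
public static final String ASSIGN_PLACEHOLDER
The name given to the assignment operation from the placeholders.
See Also:
PLACEHOLDER
public static final String PLACEHOLDER
The name of the placeholder op.
See Also:
DTYPE
public static final String DTYPE
The name of the data type.
See Also: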
Method Details
closeTensorCollection
public static void closeTensorCollection(Collection<org.tensorflow.Tensor> tensors)
Closes a collection of Tensors.
Parameters:
tensors - The collection of tensors to close.
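The close-all loop below is a minimal, JDK-only sketch of what this method does. It assumes only that a tensor is `AutoCloseable`, as `org.tensorflow.Tensor` is in TensorFlow Java 1.x; `FakeTensor` is a hypothetical stand-in so the example runs without TensorFlow on the classpath, and the loop is not claimed to be Tribuo's exact implementation.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

// Hypothetical stand-in for org.tensorflow.Tensor; only its AutoCloseable behaviour matters here.
class FakeTensor implements AutoCloseable {
    private boolean closed = false;

    boolean isClosed() {
        return closed;
    }

    @Override
    public void close() {
        // A real Tensor would release its native buffer at this point.
        closed = true;
    }
}

class CloseTensorsSketch {
    // Closes every tensor in the collection, mirroring the documented behaviour.
    static void closeTensorCollection(Collection<FakeTensor> tensors) {
        for (FakeTensor t : tensors) {
            t.close();
        }
    }

    public static void main(String[] args) {
        List<FakeTensor> outputs = new ArrayList<>();
        outputs.add(new FakeTensor());
        outputs.add(new FakeTensor());
        closeTensorCollection(outputs);
        // Prints "true": every tensor in the list has been closed.
        System.out.println(outputs.stream().allMatch(FakeTensor::isClosed));
    }
}
```

A typical caller would pass the `List<Tensor<?>>` returned by `Session.Runner.run()` once the results have been copied out, because TensorFlow Java 1.x tensors hold native resources that must be freed explicitly by calling `close()`.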
annotateGraph
public static void annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session)
Annotates a graph with an extra placeholder and assign operation for each VariableV2. This allows the graph to be deserialised using restoreMarshalledVariables(Session, Map).
This operation can either be done each time the Graph is loaded, before restoreMarshalledVariables is called, or once, with the updated GraphDef persisted alongside the Map produced by extractMarshalledVariables.
Requires a session to correctly get the output type of a VariableV2.
Parameters:
graph - The graph to annotate.
session - The session to use.
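The annotation step can be modelled without TensorFlow. In the JDK-only sketch below, a graph is a hypothetical list of `GraphOp` entries (name, op type, dtype, inputs), and for every `VariableV2` entry the sketch appends one placeholder with the same dtype plus one assign op that takes the variable and the placeholder as inputs. The op type strings and the naming scheme (`<variable>-Placeholder`, `<variable>/Assign_from_Placeholder`) are assumptions for illustration; this page does not show the values of the `PLACEHOLDER`, `ASSIGN_OP` or `ASSIGN_PLACEHOLDER` constants.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical, minimal stand-in for a TensorFlow graph operation.
final class GraphOp {
    final String name;
    final String type;
    final String dtype;
    final List<String> inputs;

    GraphOp(String name, String type, String dtype, List<String> inputs) {
        this.name = name;
        this.type = type;
        this.dtype = dtype;
        this.inputs = inputs;
    }
}

final class AnnotateGraphSketch {
    // Assumed op type names, for illustration only.
    static final String VARIABLE_V2 = "VariableV2";
    static final String PLACEHOLDER = "Placeholder";
    static final String ASSIGN_OP = "Assign";

    // Assumed naming scheme; the real generatePlaceholderName may differ.
    static String placeholderName(String variableName) {
        return variableName + "-" + PLACEHOLDER;
    }

    // Returns a copy of the graph with a placeholder and an assign op added for each VariableV2.
    static List<GraphOp> annotate(List<GraphOp> graph) {
        List<GraphOp> annotated = new ArrayList<>(graph);
        for (GraphOp op : graph) {
            if (VARIABLE_V2.equals(op.type)) {
                String ph = placeholderName(op.name);
                // The placeholder needs the variable's dtype so its value can be assigned.
                annotated.add(new GraphOp(ph, PLACEHOLDER, op.dtype, Collections.<String>emptyList()));
                // The assign op writes the placeholder's value (second input) into the variable (first input).
                annotated.add(new GraphOp(op.name + "/Assign_from_Placeholder", ASSIGN_OP, op.dtype,
                        Arrays.asList(op.name, ph)));
            }
        }
        return annotated;
    }

    public static void main(String[] args) {
        List<GraphOp> graph = new ArrayList<>();
        graph.add(new GraphOp("dense/kernel", VARIABLE_V2, "DT_FLOAT", Collections.<String>emptyList()));
        graph.add(new GraphOp("input", PLACEHOLDER, "DT_FLOAT", Collections.<String>emptyList()));
        // Prints the two original ops, then dense/kernel-Placeholder and dense/kernel/Assign_from_Placeholder.
        for (GraphOp op : annotate(graph)) {
            System.out.println(op.type + " " + op.name);
        }
    }
}
```

Here the dtype is simply stored on each `GraphOp`; in the real method it has to be discovered for each VariableV2, which is why a `Session` is required.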
generatePlaceholderName
public static String generatePlaceholderName(String variableName)
Creates a name for a placeholder based on the supplied variable name.
Parameters:
variableName - The variable name to use as a base.
Returns:
A name for the placeholder.
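What presumably matters for the annotate and restore round trip is that the name is derived deterministically from the variable name, so the step that creates a placeholder and the step that later feeds it can compute the same name independently. A minimal sketch, assuming a hypothetical `<variable>-Placeholder` format (the real format is not documented on this page):

```java
final class PlaceholderNameSketch {
    // Assumed suffix; the real value of the PLACEHOLDER constant is not shown on this page.
    static final String PLACEHOLDER_SUFFIX = "-Placeholder";

    // Deterministic: the same variable name always yields the same placeholder name,
    // so the annotate and restore steps can each compute it without sharing state.
    static String generatePlaceholderName(String variableName) {
        return variableName + PLACEHOLDER_SUFFIX;
    }

    public static void main(String[] args) {
        // Prints "dense/kernel-Placeholder".
        System.out.println(generatePlaceholderName("dense/kernel"));
    }
}
```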
extractMarshalledVariables
public static Map<String,TensorFlowUtil.TensorTuple> extractMarshalledVariables(org.tensorflow.Graph graph, org.tensorflow.Session session)
Extracts a Map containing the name of each TensorFlow VariableV2 and the associated parameter array. This map can then be serialised to disk.
Parameters:
graph - The graph to read operations from.
session - The session to read from.
Returns:
A map containing all variable names and parameter arrays.
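Because the returned map holds serializable tuples, persisting it can be as simple as standard Java object serialization. The sketch below round-trips such a map through a temporary file; `TupleSketch` is a hypothetical stand-in for `TensorFlowUtil.TensorTuple` with the three parts this page lists (class name, shape, data), and the field types chosen for them (`String`, `long[]`, `byte[]`) are assumptions.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for TensorFlowUtil.TensorTuple: class name, shape and raw data.
final class TupleSketch implements Serializable {
    private static final long serialVersionUID = 1L;
    final String className;
    final long[] shape;
    final byte[] data;

    TupleSketch(String className, long[] shape, byte[] data) {
        this.className = className;
        this.shape = shape;
        this.data = data;
    }
}

final class MarshalSketch {
    // Writes the whole variable map with standard Java serialization.
    static void save(Map<String, TupleSketch> variables, Path path) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(path))) {
            // HashMap is Serializable, as is every TupleSketch value it holds.
            out.writeObject(new HashMap<>(variables));
        }
    }

    // Reads the map back; the unchecked cast is safe only for files written by save().
    @SuppressWarnings("unchecked")
    static Map<String, TupleSketch> load(Path path) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(path))) {
            return (Map<String, TupleSketch>) in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        Map<String, TupleSketch> variables = new HashMap<>();
        // A 2x3 float variable: 6 values of 4 bytes each.
        variables.put("dense/kernel", new TupleSketch("FLOAT", new long[]{2, 3}, new byte[24]));
        Path file = Files.createTempFile("tf-variables", ".ser");
        try {
            save(variables, file);
            TupleSketch t = load(file).get("dense/kernel");
            // Prints "FLOAT [2, 3] 24".
            System.out.println(t.className + " " + Arrays.toString(t.shape) + " " + t.data.length);
        } finally {
            Files.delete(file);
        }
    }
}
```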
restoreMarshalledVariables
public static void restoreMarshalledVariables(org.tensorflow.Session session, Map<String,TensorFlowUtil.TensorTuple> tensorMap)
Writes a map containing the name of each TensorFlow VariableV2 and the associated parameter array into the supplied session.
Parameters:
session - The session to write to.
tensorMap - The parameter map to write.
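One plausible shape for the restore step, assuming the graph has already been annotated: feed each stored value into its variable's placeholder, add the matching assign op as a run target, and execute everything in a single run. `RecordingRunner` is a hypothetical recorder that mimics the `feed`/`addTarget`/`run` chain of TensorFlow Java 1.x's `Session.Runner`, the naming scheme is assumed, and plain `float[]` values stand in for rebuilt tensors; this is a sketch, not Tribuo's confirmed implementation.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical recorder shaped like TensorFlow Java 1.x's Session.Runner (feed / addTarget / run).
final class RecordingRunner {
    final Map<String, Object> feeds = new LinkedHashMap<>();
    final List<String> targets = new ArrayList<>();
    int runs = 0;

    RecordingRunner feed(String opName, Object value) {
        feeds.put(opName, value);
        return this;
    }

    RecordingRunner addTarget(String opName) {
        targets.add(opName);
        return this;
    }

    void run() {
        runs++;
    }
}

final class RestoreSketch {
    // Assumed naming scheme; it must match whatever the annotation step produced.
    static String placeholderName(String variableName) {
        return variableName + "-Placeholder";
    }

    static String assignName(String variableName) {
        return variableName + "/Assign_from_Placeholder";
    }

    // Feeds every stored value into its placeholder and runs all assign ops in one call.
    static void restore(RecordingRunner runner, Map<String, float[]> values) {
        for (Map.Entry<String, float[]> e : values.entrySet()) {
            runner.feed(placeholderName(e.getKey()), e.getValue());
            runner.addTarget(assignName(e.getKey()));
        }
        runner.run();
    }

    public static void main(String[] args) {
        Map<String, float[]> values = new LinkedHashMap<>();
        values.put("dense/kernel", new float[]{0.1f, 0.2f});
        values.put("dense/bias", new float[]{0.0f});
        RecordingRunner runner = new RecordingRunner();
        restore(runner, values);
        System.out.println(runner.feeds.keySet() + " " + runner.targets + " runs=" + runner.runs);
    }
}
```

Batching all the assignments into one run avoids a separate session call per variable. With real tensors, the fed `Tensor` objects would also need closing after the run, which is the job `closeTensorCollection` describes.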