Class TensorFlowUtil
java.lang.Object
org.tribuo.interop.tensorflow.TensorFlowUtil
Helper functions for working with TensorFlow.
Nested Class Summary
Nested Classes
static final class TensorFlowUtil.TensorTuple
    A serializable tuple containing the tensor class name, the shape and the data.
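To make the idea of a serializable tensor tuple concrete, here is a minimal, stdlib-only sketch. `TensorTupleSketch` is a hypothetical stand-in for `TensorFlowUtil.TensorTuple`. Its field names, the use of `long[]` for the shape and `byte[]` for the data, and the `"java.lang.Float"` class-name string are assumptions made for illustration, not the documented layout. The sketch shows the three pieces of information travelling together and surviving a Java serialisation round trip.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

// Hypothetical stand-in for TensorFlowUtil.TensorTuple; field names and types are assumptions.
final class TensorTupleSketch implements Serializable {
    private static final long serialVersionUID = 1L;

    final String className; // name of the tensor's element class (assumed encoding)
    final long[] shape;     // size of each dimension
    final byte[] data;      // raw element bytes (assumed representation)

    TensorTupleSketch(String className, long[] shape, byte[] data) {
        this.className = className;
        this.shape = shape.clone();
        this.data = data.clone();
    }

    // Number of elements implied by the shape: the product of its dimensions.
    long numElements() {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    // Writes the tuple with Java serialisation and reads it back.
    static TensorTupleSketch roundTrip(TensorTupleSketch t) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(t);
            }
            try (ObjectInputStream in =
                     new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (TensorTupleSketch) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Round trip failed", e);
        }
    }

    @Override
    public String toString() {
        return className + Arrays.toString(shape) + " (" + data.length + " bytes)";
    }

    public static void main(String[] args) {
        // A 2x3 float tensor: 6 elements * 4 bytes each = 24 bytes of data.
        TensorTupleSketch t = new TensorTupleSketch("java.lang.Float", new long[] {2, 3}, new byte[24]);
        System.out.println(roundTrip(t));
    }
}
```

Storing the data as raw bytes plus a class name keeps the tuple independent of TensorFlow types. That is presumably what lets such a map be written to disk with ordinary Java serialisation.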
Field Summary
Fields
static final String ASSIGN_OP
    The name of the assignment op.
static final String ASSIGN_PLACEHOLDER
    The name given to the assignment operation from the placeholders.
static final String DTYPE
    The name of the data type.
static final String PLACEHOLDER
    The name of the placeholder op.
static final String VARIABLE_V2
    The name of the variable op.
Method Summary
static void annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session)
    Annotates a graph with an extra placeholder and assign operation for each VariableV2.
static void closeTensorCollection(Collection<org.tensorflow.Tensor> tensors)
    Closes a collection of Tensors.
static Map<String, TensorFlowUtil.TensorTuple> extractMarshalledVariables(org.tensorflow.Graph graph, org.tensorflow.Session session)
    Extracts a Map containing the name of each TensorFlow VariableV2 and the associated parameter array.
static String generatePlaceholderName(String variableName)
    Creates a name for a placeholder based on the supplied variable name.
static void restoreMarshalledVariables(org.tensorflow.Session session, Map<String, TensorFlowUtil.TensorTuple> tensorMap)
    Writes a map containing the name of each TensorFlow VariableV2 and the associated parameter array into the supplied session.
Field Details
VARIABLE_V2
The name of the variable op.
ASSIGN_OP
The name of the assignment op.
ASSIGN_PLACEHOLDER
The name given to the assignment operation from the placeholders.
PLACEHOLDER
The name of the placeholder op.
DTYPE
The name of the data type.
Method Details
closeTensorCollection
public static void closeTensorCollection(Collection<org.tensorflow.Tensor> tensors)
Closes a collection of Tensors.
Parameters:
tensors - The collection of tensors to close.
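The close-a-collection idea can be sketched without TensorFlow on the classpath. In this sketch `AutoCloseable` stands in for `org.tensorflow.Tensor`, and `CloseAllSketch.closeAll` is a hypothetical helper, not Tribuo's implementation. The failure handling is a design choice of the sketch: it keeps closing after an exception and reports the first failure with later ones suppressed. The documentation does not say whether `closeTensorCollection` behaves this way.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

// Sketch of the close-them-all pattern; AutoCloseable stands in for org.tensorflow.Tensor.
final class CloseAllSketch {
    private CloseAllSketch() {}

    // Closes every resource even if some close() calls fail, then reports the
    // first failure with any later failures attached as suppressed exceptions.
    static void closeAll(Collection<? extends AutoCloseable> resources) {
        Exception first = null;
        for (AutoCloseable r : resources) {
            try {
                r.close();
            } catch (Exception e) {
                if (first == null) {
                    first = e;
                } else {
                    first.addSuppressed(e);
                }
            }
        }
        if (first != null) {
            throw new IllegalStateException("Failed to close all resources", first);
        }
    }

    public static void main(String[] args) {
        List<AutoCloseable> tensors = new ArrayList<>();
        for (int i = 0; i < 3; i++) {
            final int id = i;
            tensors.add(() -> System.out.println("closed " + id));
        }
        closeAll(tensors);
    }
}
```

Tensors hold native memory, so releasing them explicitly matters. Continuing past a failed `close()` avoids leaking the remaining ones.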
annotateGraph
public static void annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session)
Annotates a graph with an extra placeholder and assign operation for each VariableV2. This allows the graph to be deserialised using restoreMarshalledVariables(Session, Map). The annotation can be done in either of two ways: each time the Graph is loaded, before restoreMarshalledVariables is called, or once, with the updated GraphDef persisted alongside the Map produced by extractMarshalledVariables.
Requires a session to correctly get the output type of a VariableV2.
Parameters:
graph - The graph to annotate.
session - The session to use.
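The following toy model sketches what the annotation adds. The "graph" here is just a map from op name to op type. For every `VariableV2` op, the sketch adds a `Placeholder` op and an `Assign` op that would copy the placeholder's value into the variable. `VariableV2`, `Placeholder` and `Assign` are TensorFlow op types. The values of Tribuo's own constants and the naming scheme (variable name plus `-Placeholder` / `-Assign`) are assumptions, and the real method additionally uses the session to read each variable's data type, which this model omits.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Toy model of annotateGraph: op name -> op type. Naming scheme is hypothetical.
final class AnnotateSketch {
    static final String VARIABLE_V2 = "VariableV2";
    static final String PLACEHOLDER = "Placeholder";
    static final String ASSIGN = "Assign";

    private AnnotateSketch() {}

    // Hypothetical placeholder name derived from the variable name.
    static String placeholderName(String variableName) {
        return variableName + "-" + PLACEHOLDER;
    }

    // Hypothetical name for the assign op that reads from the placeholder.
    static String assignName(String variableName) {
        return variableName + "-" + ASSIGN;
    }

    // Adds one placeholder and one assign op per variable; returns the variables found.
    static List<String> annotate(Map<String, String> graph) {
        List<String> variables = new ArrayList<>();
        for (Map.Entry<String, String> op : graph.entrySet()) {
            if (VARIABLE_V2.equals(op.getValue())) {
                variables.add(op.getKey());
            }
        }
        // Add the new ops after iterating to avoid ConcurrentModificationException.
        for (String v : variables) {
            graph.put(placeholderName(v), PLACEHOLDER);
            graph.put(assignName(v), ASSIGN);
        }
        return variables;
    }

    public static void main(String[] args) {
        Map<String, String> graph = new LinkedHashMap<>();
        graph.put("weights", VARIABLE_V2);
        graph.put("matmul", "MatMul");
        annotate(graph);
        System.out.println(graph);
    }
}
```

Because the added ops become part of the graph definition, persisting the annotated GraphDef once is an alternative to re-annotating on every load, as the description above notes.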
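generatePlaceholderName
public static String generatePlaceholderName(String variableName)
Creates a name for a placeholder based on the supplied variable name.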
extractMarshalledVariables
public static Map<String, TensorFlowUtil.TensorTuple> extractMarshalledVariables(org.tensorflow.Graph graph, org.tensorflow.Session session)
Extracts a Map containing the name of each TensorFlow VariableV2 and the associated parameter array. This map can then be serialised to disk.
Parameters:
graph - The graph to read operations from.
session - The session to read from.
Returns:
A map containing all variable names and parameter arrays.
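The extract-then-serialise step can be sketched with a toy "session" in which a `Map<String, float[]>` holds each variable's current value. `ExtractSketch`, its methods and the `float[]` representation are hypothetical simplifications of the real `Map<String, TensorFlowUtil.TensorTuple>`. The sketch takes a snapshot copy of every variable and Java-serialises the resulting map to bytes. Writing to a `FileOutputStream` instead of a `ByteArrayOutputStream` would put it on disk.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;

// Toy model of extractMarshalledVariables: the "session" maps variable names to values.
final class ExtractSketch {
    private ExtractSketch() {}

    // Copies every variable's current value, so later training steps don't alter the snapshot.
    static HashMap<String, float[]> extract(Map<String, float[]> sessionVariables) {
        HashMap<String, float[]> out = new HashMap<>();
        for (Map.Entry<String, float[]> e : sessionVariables.entrySet()) {
            out.put(e.getKey(), e.getValue().clone());
        }
        return out;
    }

    // Java-serialises the extracted map; HashMap and float[] are both Serializable.
    static byte[] serialise(HashMap<String, float[]> variables) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(variables);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    public static void main(String[] args) {
        Map<String, float[]> session = new HashMap<>();
        session.put("weights", new float[] {0.5f, -1.0f});
        byte[] saved = serialise(extract(session));
        System.out.println("serialised " + saved.length + " bytes");
    }
}
```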
restoreMarshalledVariables
public static void restoreMarshalledVariables(org.tensorflow.Session session, Map<String, TensorFlowUtil.TensorTuple> tensorMap)
Writes a map containing the name of each TensorFlow VariableV2 and the associated parameter array into the supplied session.
Parameters:
session - The session to write to.
tensorMap - The parameter map to write.
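Restoring is the mirror image of extraction. In this toy model the "session" is again a `Map<String, float[]>`. Writing a saved value back stands in for what the annotation makes possible in a real graph: feeding the value to the variable's placeholder and running the matching assign op. `RestoreSketch` is hypothetical, and the choice to reject names the session does not contain is the sketch's own, not documented behaviour of `restoreMarshalledVariables`.

```java
import java.util.HashMap;
import java.util.Map;

// Toy model of restoreMarshalledVariables: writes saved values back into the "session".
final class RestoreSketch {
    private RestoreSketch() {}

    // For each saved variable, copy its value into the session. In a real annotated
    // graph this corresponds to feeding the placeholder and running the assign op.
    static void restore(Map<String, float[]> session, Map<String, float[]> saved) {
        for (Map.Entry<String, float[]> e : saved.entrySet()) {
            if (!session.containsKey(e.getKey())) {
                throw new IllegalArgumentException("Unknown variable: " + e.getKey());
            }
            session.put(e.getKey(), e.getValue().clone());
        }
    }

    public static void main(String[] args) {
        Map<String, float[]> session = new HashMap<>();
        session.put("weights", new float[] {0f, 0f});
        Map<String, float[]> saved = new HashMap<>();
        saved.put("weights", new float[] {0.5f, -1.0f});
        restore(session, saved);
        System.out.println(session.get("weights")[0] + ", " + session.get("weights")[1]);
    }
}
```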