public abstract class TensorFlowUtil extends Object
Modifier and Type | Class and Description |
---|---|
static class | TensorFlowUtil.TensorTuple: A serializable tuple containing the tensor class name, the shape and the data. |
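To make the idea of a serializable tensor tuple concrete, the sketch below stores a class name, a shape and flattened data in a plain object and round-trips it through standard Java serialization. `TupleSketch`, its field names, and the choice of a `float[]` for the data are assumptions made for this illustration; this is not the actual TensorFlowUtil.TensorTuple implementation.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.util.Arrays;

/** Hypothetical stand-in: class name, shape and data of one tensor. */
final class TupleSketch implements Serializable {
    private static final long serialVersionUID = 1L;

    final String className; // name of the tensor class (assumed to be a plain string)
    final long[] shape;     // size of each dimension
    final float[] data;     // flattened values (float is an assumption)

    TupleSketch(String className, long[] shape, float[] data) {
        this.className = className;
        // Defensive copies so later changes by the caller do not leak in.
        this.shape = Arrays.copyOf(shape, shape.length);
        this.data = Arrays.copyOf(data, data.length);
    }

    /** Number of elements implied by the shape. */
    long numElements() {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    /** Serializes this tuple with standard Java serialization. */
    byte[] toBytes() {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(buffer)) {
            out.writeObject(this);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return buffer.toByteArray();
    }

    /** Reads a tuple back from bytes produced by toBytes(). */
    static TupleSketch fromBytes(byte[] bytes) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return (TupleSketch) in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Keeping the shape next to the flattened data is what lets a reader rebuild a tensor of the right dimensions after deserialization.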
Modifier and Type | Field and Description |
---|---|
static String | ASSIGN_OP |
static String | ASSIGN_PLACEHOLDER |
static String | DTYPE |
static String | PLACEHOLDER |
static String | VARIABLE_V2 |
Modifier and Type | Method and Description |
---|---|
static void | annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session): Annotates a graph with an extra placeholder and assign operation for each VariableV2. |
static void | closeTensorCollection(Collection<org.tensorflow.Tensor> tensors): Closes a collection of Tensors. |
static Map<String,TensorFlowUtil.TensorTuple> | extractMarshalledVariables(org.tensorflow.Graph graph, org.tensorflow.Session session): Extracts a Map containing the name of each TensorFlow VariableV2 and the associated parameter array. |
static String | generatePlaceholderName(String variableName): Creates a name for a placeholder based on the supplied variable name. |
static void | restoreMarshalledVariables(org.tensorflow.Session session, Map<String,TensorFlowUtil.TensorTuple> tensorMap): Writes a map containing the name of each TensorFlow VariableV2 and the associated parameter array into the supplied session. |
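Taken together, these methods suggest a save-and-restore flow: copy every variable out of one session into a map, then write that map into another session whose graph has a placeholder for each variable. The sketch below imitates that flow with plain Java collections so that it runs without TensorFlow. `FakeSession`, `MarshalSketch`, the `"-placeholder"` suffix and every method body are assumptions made for illustration; they are not the Tribuo or TensorFlow implementation, and the real placeholder naming scheme is not stated on this page.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

/** Hypothetical stand-in for a session: variable name -> parameter array. */
final class FakeSession {
    final Map<String, float[]> variables = new LinkedHashMap<>();
    // Placeholder names registered by annotate(); restore() needs one per variable.
    final Set<String> placeholders = new HashSet<>();
}

final class MarshalSketch {
    private MarshalSketch() {}

    /** Assumed naming scheme, for illustration only. */
    static String placeholderName(String variableName) {
        return variableName + "-placeholder";
    }

    /** Mimics annotation: registers one placeholder per variable. */
    static void annotate(FakeSession session) {
        for (String name : session.variables.keySet()) {
            session.placeholders.add(placeholderName(name));
        }
    }

    /** Mimics extraction: copies each variable's parameters into a map. */
    static Map<String, float[]> extract(FakeSession session) {
        Map<String, float[]> out = new LinkedHashMap<>();
        for (Map.Entry<String, float[]> e : session.variables.entrySet()) {
            out.put(e.getKey(), Arrays.copyOf(e.getValue(), e.getValue().length));
        }
        return out;
    }

    /** Mimics restoration: writes each array back through its placeholder. */
    static void restore(FakeSession session, Map<String, float[]> tensorMap) {
        for (Map.Entry<String, float[]> e : tensorMap.entrySet()) {
            if (!session.placeholders.contains(placeholderName(e.getKey()))) {
                throw new IllegalStateException("Graph not annotated for " + e.getKey());
            }
            session.variables.put(e.getKey(), Arrays.copyOf(e.getValue(), e.getValue().length));
        }
    }
}
```

In this sketch, calling `restore` before `annotate` fails. With the real API the corresponding order would be annotateGraph on the loaded graph, then restoreMarshalledVariables with a map obtained from extractMarshalledVariables.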
public static final String VARIABLE_V2
public static final String ASSIGN_OP
public static final String ASSIGN_PLACEHOLDER
public static final String PLACEHOLDER
public static final String DTYPE
public static void closeTensorCollection(Collection<org.tensorflow.Tensor> tensors)
Closes a collection of Tensors.
tensors - The collection of tensors to close.
public static void annotateGraph(org.tensorflow.Graph graph, org.tensorflow.Session session)
Annotates a graph with an extra placeholder and assign operation for each VariableV2, as needed by restoreMarshalledVariables(Session, Map).
This operation can be done either each time the Graph is loaded, before restoreMarshalledVariables is called, or once, with the updated graphDef persisted alongside the Map produced by extractMarshalledVariables.
Requires a session to correctly get the output type of a VariableV2.
graph - The graph to annotate.
session - The session to use.
public static String generatePlaceholderName(String variableName)
variableName - The variable name to use as a base.
public static Map<String,TensorFlowUtil.TensorTuple> extractMarshalledVariables(org.tensorflow.Graph graph, org.tensorflow.Session session)
graph - The graph to read operations from.
session - The session to read from.
public static void restoreMarshalledVariables(org.tensorflow.Session session, Map<String,TensorFlowUtil.TensorTuple> tensorMap)
session - The session to write to.
tensorMap - The parameter map to write.
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.