public final class TensorFlowCheckpointModel<T extends Output<T>> extends TensorFlowModel<T> implements Closeable
If the checkpoint is not available on construction or after deserialisation then the
model is uninitialised. To initialise the model, first call setCheckpointDirectory(java.lang.String)
and setCheckpointName(java.lang.String) with the correct directory and name respectively,
then call initialize().
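The call order described above can be sketched with a small stand-in class. This sketch does not use Tribuo or TensorFlow: `CheckpointLifecycleSketch` and its nested `StandInModel` are hypothetical and only mirror the documented method names. It assumes a checkpoint counts as "available" when its directory exists on disk, and it throws `IllegalStateException` where the real model is documented to throw `TensorFlowException`.

```java
import java.nio.file.Files;
import java.nio.file.Paths;

public class CheckpointLifecycleSketch {

    /** Hypothetical stand-in; NOT Tribuo's TensorFlowCheckpointModel. */
    static final class StandInModel {
        private String checkpointDirectory;
        private String checkpointName;
        private boolean initialized;

        StandInModel(String checkpointDirectory, String checkpointName) {
            this.checkpointDirectory = checkpointDirectory;
            this.checkpointName = checkpointName;
            // A missing checkpoint at construction leaves the model uninitialised.
            this.initialized = checkpointAvailable();
        }

        // Assumption: "available" simply means the directory exists.
        private boolean checkpointAvailable() {
            return checkpointDirectory != null
                    && checkpointName != null
                    && Files.isDirectory(Paths.get(checkpointDirectory));
        }

        void setCheckpointDirectory(String newCheckpointDirectory) {
            this.checkpointDirectory = newCheckpointDirectory;
        }

        void setCheckpointName(String newCheckpointName) {
            this.checkpointName = newCheckpointName;
        }

        void initialize() {
            if (!checkpointAvailable()) {
                // The real model is documented to throw TensorFlowException here.
                throw new IllegalStateException("Cannot read checkpoint in " + checkpointDirectory);
            }
            initialized = true;
        }

        boolean isInitialized() {
            return initialized;
        }
    }

    public static void main(String[] args) {
        String existing = System.getProperty("java.io.tmpdir");
        String missing = Paths.get(existing, "no-such-checkpoint-" + System.nanoTime()).toString();

        // Constructed with a missing directory, so the model starts uninitialised.
        StandInModel model = new StandInModel(missing, "model");
        System.out.println("Initialised after construction? " + model.isInitialized());

        // Point the model at a valid location, then initialise it.
        model.setCheckpointDirectory(existing);
        model.setCheckpointName("model");
        model.initialize();
        System.out.println("Initialised after initialize()? " + model.isInitialized());
    }
}
```

With the real class the same three calls apply, but the checkpoint must actually be readable by TensorFlow.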
This model's serialized form stores the weights in the specified model checkpoint directory.
If you wish to convert it into a model which stores the weights inside the model then
call convertToNativeModel().
It accepts a FeatureConverter that converts an example's features into a TensorMap,
and an OutputConverter that converts a Tensor into a Prediction.
The model's serialVersionUID is set to the major TensorFlow version number times 100.
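As a worked example of that rule, the sketch below derives the value from a version string. The helper `uidForVersion` and the version string `"2.4.1"` are illustrative assumptions, not part of the library, and this page does not say which TensorFlow version is in use.

```java
public class SerialVersionSketch {

    // Hypothetical helper: takes a version string such as "2.4.1"
    // and returns the major version number times 100.
    static long uidForVersion(String tensorflowVersion) {
        int major = Integer.parseInt(tensorflowVersion.split("\\.")[0]);
        return major * 100L;
    }

    public static void main(String[] args) {
        // "2.4.1" is an illustrative version string only.
        System.out.println(uidForVersion("2.4.1")); // prints 200
    }
}
```

One consequence of this scheme is that models serialized against one major TensorFlow version will not deserialize against another, because the serialVersionUID values differ.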
N.B. TensorFlow support is experimental and may change without a major version bump.
batchSize, closed, featureConverter, modelGraph, outputConverter, outputName, session
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Modifier and Type | Method and Description |
---|---|
TensorFlowNativeModel<T> | convertToNativeModel(): Creates a TensorFlowNativeModel version of this model. |
protected TensorFlowCheckpointModel<T> | copy(String newName, ModelProvenance newProvenance): Copies a model, replacing its provenance and name with the supplied values. |
String | getCheckpointDirectory(): Gets the checkpoint directory this model loads from. |
String | getCheckpointName(): Gets the checkpoint name this model loads from. |
void | initialize(): Initializes the model. |
boolean | isInitialized(): Is this model initialized? |
void | setCheckpointDirectory(String newCheckpointDirectory): Sets the checkpoint directory. |
void | setCheckpointName(String newCheckpointName): Sets the checkpoint name. |
close, exportModel, getBatchSize, getExcuse, getOutputName, getTopFeatures, innerPredict, predict, setBatchSize
copy, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, setName, toString, validate
public boolean isInitialized()
Is this model initialized?
public final void initialize()
Initializes the model.
This call closes the old session (if it exists) and creates a fresh session from the current checkpoint path.
Throws TensorFlowException if it failed to read the checkpoint.
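The session replacement described for initialize() can be sketched as follows. This is a minimal illustration under stated assumptions, not Tribuo's implementation: `StandInSession` and `StandInModel` are hypothetical classes, the fields are left package-visible for brevity, and joining the directory and name with "/" is an assumed way of forming the checkpoint path.

```java
public class ReinitializeSketch {

    /** Hypothetical stand-in for a session; it only records its path and closed state. */
    static final class StandInSession implements AutoCloseable {
        final String checkpointPath;
        boolean closed = false;

        StandInSession(String checkpointPath) {
            this.checkpointPath = checkpointPath;
        }

        @Override
        public void close() {
            closed = true;
        }
    }

    /** Hypothetical stand-in model; NOT Tribuo's TensorFlowCheckpointModel. */
    static final class StandInModel {
        String checkpointDirectory;
        String checkpointName;
        StandInSession session; // null until initialize() is first called

        void initialize() {
            if (session != null) {
                session.close(); // close the old session, if it exists
            }
            // Assumption: the checkpoint path is directory + "/" + name.
            session = new StandInSession(checkpointDirectory + "/" + checkpointName);
        }
    }

    public static void main(String[] args) {
        StandInModel model = new StandInModel();
        model.checkpointDirectory = "checkpoints-a"; // illustrative paths only
        model.checkpointName = "model";
        model.initialize();
        StandInSession first = model.session;

        // Changing the location has no effect until initialize() is called again.
        model.checkpointDirectory = "checkpoints-b";
        model.initialize();

        System.out.println(first.closed);                 // prints true
        System.out.println(model.session.checkpointPath); // prints checkpoints-b/model
    }
}
```

In this sketch, any reference held to the old session is unusable after re-initialization, because that session has been closed.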
public void setCheckpointDirectory(String newCheckpointDirectory)
Sets the checkpoint directory.
The model likely needs re-initializing after this call.
newCheckpointDirectory - The new checkpoint directory.
public String getCheckpointDirectory()
Gets the checkpoint directory this model loads from.
public void setCheckpointName(String newCheckpointName)
Sets the checkpoint name.
The model likely needs re-initializing after this call.
newCheckpointName - The new checkpoint name.
public String getCheckpointName()
Gets the checkpoint name this model loads from.
public TensorFlowNativeModel<T> convertToNativeModel()
Creates a TensorFlowNativeModel version of this model.
protected TensorFlowCheckpointModel<T> copy(String newName, ModelProvenance newProvenance)
Description copied from class: Model
Copies a model, replacing its provenance and name with the supplied values.
Used to provide the provenance removal functionality.
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.