Package org.tribuo.interop.tensorflow
Class TensorFlowNativeModel<T extends Output<T>>
java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.tensorflow.TensorFlowModel<T>
org.tribuo.interop.tensorflow.TensorFlowNativeModel<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, AutoCloseable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
This model encapsulates a TensorFlow model running in graph mode with a single tensor output.
It accepts a FeatureConverter that converts an example's features into a set of Tensors, and an OutputConverter that converts a Tensor into a Prediction.
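To make the data flow concrete, here is a stdlib-only sketch of the two-converter pipeline: named features become a dense input, a stand-in for the graph produces a single output vector, and that vector becomes a label. `HypotheticalFeatureConverter`, `HypotheticalOutputConverter` and `runGraph` are illustrative assumptions, not Tribuo's `FeatureConverter`/`OutputConverter` interfaces or TensorFlow's `Tensor` type.

```java
import java.util.List;
import java.util.Map;

// Stdlib-only sketch of the FeatureConverter -> graph -> OutputConverter flow.
// All classes here are hypothetical stand-ins, NOT Tribuo or TensorFlow types.
public class ConverterPipelineSketch {

    /** Maps named features to a dense input vector (stands in for a Tensor). */
    static final class HypotheticalFeatureConverter {
        private final List<String> featureOrder;

        HypotheticalFeatureConverter(List<String> featureOrder) {
            this.featureOrder = featureOrder;
        }

        float[] convert(Map<String, Float> features) {
            float[] input = new float[featureOrder.size()];
            for (int i = 0; i < featureOrder.size(); i++) {
                // Features absent from the example are treated as zero.
                input[i] = features.getOrDefault(featureOrder.get(i), 0.0f);
            }
            return input;
        }
    }

    /** Maps the single output vector (stands in for the output Tensor) to a label. */
    static final class HypotheticalOutputConverter {
        private final List<String> labels;

        HypotheticalOutputConverter(List<String> labels) {
            this.labels = labels;
        }

        String convert(float[] scores) {
            int best = 0;
            for (int i = 1; i < scores.length; i++) {
                if (scores[i] > scores[best]) {
                    best = i;
                }
            }
            return labels.get(best);
        }
    }

    /** Placeholder for graph execution: a fixed linear layer with one output vector. */
    static float[] runGraph(float[] input, float[][] weights) {
        float[] out = new float[weights.length];
        for (int o = 0; o < weights.length; o++) {
            for (int i = 0; i < input.length; i++) {
                out[o] += weights[o][i] * input[i];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        var featureConverter = new HypotheticalFeatureConverter(List.of("height", "width"));
        var outputConverter = new HypotheticalOutputConverter(List.of("small", "large"));
        float[][] weights = {{-1.0f, -1.0f}, {1.0f, 1.0f}};

        float[] input = featureConverter.convert(Map.of("height", 2.0f, "width", 3.0f));
        String prediction = outputConverter.convert(runGraph(input, weights));
        System.out.println(prediction); // prints "large"
    }
}
```

Keeping the two conversions separate from the graph means the same graph can serve different feature encodings or output types, which is the role the two converters play in the description above.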
This model's serialized form stores the weights and is entirely self-contained.
If you wish to convert it into a model which uses checkpoints, call convertToCheckpointModel(java.lang.String, java.lang.String).
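As a rough illustration of "self-contained", this stdlib sketch round-trips an object through Java serialization: because the weights are fields of the object, the deserialized copy needs no external files. `HypotheticalSelfContainedModel` is an assumption for illustration and does not reflect how Tribuo actually encodes the TensorFlow graph or variables.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

// Sketch: the weights travel inside the serialized bytes, so no checkpoint
// directory is needed to restore the object. Hypothetical class, not Tribuo's.
public class SelfContainedSketch {

    static final class HypotheticalSelfContainedModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final String name;
        final float[] weights; // serialized together with the object

        HypotheticalSelfContainedModel(String name, float[] weights) {
            this.name = name;
            this.weights = weights;
        }
    }

    static byte[] toBytes(Serializable obj) throws IOException {
        var bytes = new ByteArrayOutputStream();
        try (var out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        }
        return bytes.toByteArray();
    }

    static Object fromBytes(byte[] data) throws IOException, ClassNotFoundException {
        try (var in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        var model = new HypotheticalSelfContainedModel("demo", new float[] {0.5f, -1.25f});
        var restored = (HypotheticalSelfContainedModel) fromBytes(toBytes(model));
        System.out.println(restored.name + " " + Arrays.toString(restored.weights)); // prints "demo [0.5, -1.25]"
    }
}
```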
The model's serialVersionUID is set to the major TensorFlow version number times 100.
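The serialVersionUID rule is simple arithmetic; the sketch below spells it out. The major versions passed in are example inputs only and are not a claim about which TensorFlow release a given Tribuo version uses.

```java
// Worked example of the convention: serialVersionUID = major TensorFlow version * 100.
// The inputs are illustrative, not the actual bundled TensorFlow version.
public class SerialVersionSketch {

    static long serialVersionUidFor(int tensorFlowMajorVersion) {
        return tensorFlowMajorVersion * 100L;
    }

    public static void main(String[] args) {
        System.out.println(serialVersionUidFor(2)); // prints 200
        // A different major version yields a different UID, so Java serialization
        // would throw InvalidClassException when reading a model written under the other one.
        System.out.println(serialVersionUidFor(3)); // prints 300
    }
}
```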
N.B. TensorFlow support is experimental and may change without a major version bump.
-
Field Summary
Modifier and Type | Field | Description
static final int | CURRENT_VERSION | Protobuf serialization version.
Fields inherited from class org.tribuo.interop.tensorflow.TensorFlowModel
batchSize, closed, featureConverter, modelGraph, outputConverter, outputName, session
Fields inherited from class org.tribuo.Model
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
-
Method Summary
Modifier and Type | Method | Description
TensorFlowCheckpointModel<T> | convertToCheckpointModel(String checkpointDirectory, String checkpointName) | Creates a TensorFlowCheckpointModel version of this model.
protected TensorFlowNativeModel<T> | copy(String newName, ModelProvenance newProvenance) | Copies a model, replacing its provenance and name with the supplied values.
static TensorFlowNativeModel<?> | deserializeFromProto(int version, String className, com.google.protobuf.Any message) | Deserialization factory.
org.tribuo.protos.core.ModelProto | serialize() | Serializes this object to a protobuf.
Methods inherited from class org.tribuo.interop.tensorflow.TensorFlowModel
close, exportModel, getBatchSize, getExcuse, getOutputName, getTopFeatures, innerPredict, predict, setBatchSize
Methods inherited from class org.tribuo.Model
castModel, copy, createDataCarrier, deserialize, deserializeFromFile, deserializeFromStream, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, serializeToFile, serializeToStream, setName, toString, validate
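Because the class implements AutoCloseable (with close() inherited from TensorFlowModel), a try-with-resources block is a natural way to release the native TensorFlow resources. The sketch below uses a hypothetical stand-in class; only the close() contract comes from the documentation above, and the "throw after close" behaviour is an assumption of this sketch.

```java
// Sketch of deterministic cleanup via AutoCloseable.
// HypotheticalNativeModel stands in for TensorFlowNativeModel; its behaviour
// after close() is an assumption made for illustration.
public class CloseSketch {

    static final class HypotheticalNativeModel implements AutoCloseable {
        private boolean closed = false;

        String predict(String input) {
            if (closed) {
                throw new IllegalStateException("Model has been closed");
            }
            return "prediction for " + input;
        }

        boolean isClosed() {
            return closed;
        }

        @Override
        public void close() {
            // A real implementation would release the native graph and session here.
            closed = true;
        }
    }

    public static void main(String[] args) {
        HypotheticalNativeModel escaped = null;
        try (var model = new HypotheticalNativeModel()) {
            System.out.println(model.predict("x")); // prints "prediction for x"
            escaped = model;
        }
        // close() ran automatically when the try block exited.
        System.out.println(escaped.isClosed()); // prints true
    }
}
```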
-
Field Details
-
CURRENT_VERSION
public static final int CURRENT_VERSION
Protobuf serialization version.
-
Method Details
-
deserializeFromProto
public static TensorFlowNativeModel<?> deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
Deserialization factory.
- Parameters:
version - The serialized object version.
className - The class name.
message - The serialized data.
- Returns:
The deserialized object.
- Throws:
com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
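The version parameter lets a factory refuse payloads written by a newer serializer. The sketch below shows one way to gate on CURRENT_VERSION; the `byte[]` payload stands in for `com.google.protobuf.Any`, the value 0 for CURRENT_VERSION is hypothetical, and the exact rejection rule is an assumption rather than a description of Tribuo's code.

```java
import java.nio.charset.StandardCharsets;

// Sketch of a version-gated deserialization factory.
// byte[] stands in for com.google.protobuf.Any; all names are hypothetical.
public class VersionGateSketch {

    static final int CURRENT_VERSION = 0; // hypothetical value, not Tribuo's actual constant

    static final class HypotheticalModel {
        final String payload;

        HypotheticalModel(String payload) {
            this.payload = payload;
        }
    }

    static HypotheticalModel deserialize(int version, String className, byte[] message) {
        // Reject versions this class does not know how to read.
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version
                    + ", this class supports at most version " + CURRENT_VERSION);
        }
        // className is ignored in this sketch; a real factory would parse the
        // protobuf message here and rebuild the model from its fields.
        return new HypotheticalModel(new String(message, StandardCharsets.UTF_8));
    }

    public static void main(String[] args) {
        byte[] data = "weights".getBytes(StandardCharsets.UTF_8);
        System.out.println(deserialize(0, "Demo", data).payload); // prints "weights"
        try {
            deserialize(CURRENT_VERSION + 1, "Demo", data);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```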
-
copy
protected TensorFlowNativeModel<T> copy(String newName, ModelProvenance newProvenance)
Description copied from class: Model
Copies a model, replacing its provenance and name with the supplied values. Used to provide the provenance removal functionality.
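The copy contract keeps the learned parameters while swapping in a new name and provenance, which is how provenance can be scrubbed from a trained model. In this sketch a `Map<String, String>` stands in for ModelProvenance and `HypotheticalModel` for the model class; cloning the weights is a choice made here, not a statement about Tribuo's implementation.

```java
import java.util.Map;

// Sketch of copy(newName, newProvenance): same parameters, new name and provenance.
// HypotheticalModel and the map-based "provenance" are illustrative stand-ins.
public class CopySketch {

    static final class HypotheticalModel {
        final String name;
        final Map<String, String> provenance;
        final float[] weights;

        HypotheticalModel(String name, Map<String, String> provenance, float[] weights) {
            this.name = name;
            this.provenance = provenance;
            this.weights = weights;
        }

        HypotheticalModel copy(String newName, Map<String, String> newProvenance) {
            // Clone the parameters so the copy is independent of the original.
            return new HypotheticalModel(newName, newProvenance, weights.clone());
        }
    }

    public static void main(String[] args) {
        var original = new HypotheticalModel("trained",
                Map.of("dataset", "/data/train.csv"), new float[] {0.1f, 0.2f});
        // Provenance removal: copy with an empty provenance record.
        var scrubbed = original.copy("scrubbed", Map.of());
        System.out.println(scrubbed.name + " " + scrubbed.provenance.isEmpty()); // prints "scrubbed true"
    }
}
```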
-
convertToCheckpointModel
public TensorFlowCheckpointModel<T> convertToCheckpointModel(String checkpointDirectory, String checkpointName)
Creates a TensorFlowCheckpointModel version of this model.
- Parameters:
checkpointDirectory - The directory to write the checkpoint to.
checkpointName - The name of the checkpoint files.
- Returns:
A version of this model using a TensorFlow checkpoint to store the parameters.
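The two parameters name a directory and a file name within it. The sketch below shows one plausible way they combine into a checkpoint path prefix; joining them with `Path.resolve` is an assumption, and TensorFlow's V2 checkpoint format typically writes files such as `<prefix>.index` alongside data shards, but the exact files written here are not specified by this page.

```java
import java.nio.file.Path;

// Sketch: combining checkpointDirectory and checkpointName into a path prefix.
// The joining rule is an assumption, not a description of Tribuo's implementation.
public class CheckpointPathSketch {

    static Path checkpointPrefix(String checkpointDirectory, String checkpointName) {
        return Path.of(checkpointDirectory).resolve(checkpointName);
    }

    public static void main(String[] args) {
        Path prefix = checkpointPrefix("/tmp/tf-ckpt", "model");
        System.out.println(prefix); // prints /tmp/tf-ckpt/model on Unix-like systems
        // A typical TensorFlow V2 checkpoint index file sits next to the prefix.
        System.out.println(prefix.resolveSibling(prefix.getFileName() + ".index"));
    }
}
```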
-
serialize
public org.tribuo.protos.core.ModelProto serialize()
Description copied from interface: ProtoSerializable
Serializes this object to a protobuf.
-