Package org.tribuo.interop.tensorflow
Class TensorFlowModel<T extends Output<T>>
java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.tensorflow.TensorFlowModel<T>
- Type Parameters:
T - The output type.
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, AutoCloseable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
- Direct Known Subclasses:
TensorFlowCheckpointModel, TensorFlowNativeModel
public abstract class TensorFlowModel<T extends Output<T>>
extends Model<T>
implements AutoCloseable
Base class for a TensorFlow model that operates on Examples.
The subclasses are package private and concern themselves with how the model is stored on disk.
N.B. TensorFlow support is experimental and may change without a major version bump.
-
Field Summary
Modifier and Type    Field
protected int    batchSize
protected boolean    closed
protected final FeatureConverter    featureConverter
protected org.tensorflow.Graph    modelGraph
protected final OutputConverter<T>    outputConverter
protected final String    outputName
protected org.tensorflow.Session    session
Fields inherited from class org.tribuo.Model
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
-
Constructor Summary
Modifier    Constructor    Description
protected    TensorFlowModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, org.tensorflow.proto.framework.GraphDef trainedGraphDef, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter)    Builds a TFModel.
-
Method Summary
Modifier and Type    Method    Description
void    close()
void    exportModel(String path)    Exports this model as a SavedModelBundle, writing to the supplied directory.
int    getBatchSize()    Gets the current testing batch size.
Optional<Excuse<T>>    getExcuse(Example<T> e)    Deep learning models don't do excuses.
String    getOutputName()    Gets the name of the output operation.
Map<String,List<Pair<String,Double>>>    getTopFeatures(int n)    Deep learning models don't do feature rankings.
protected List<Prediction<T>>    innerPredict(Iterable<Example<T>> examples)    Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
Prediction<T>    predict(Example<T> example)    Uses the model to predict the output for a single example.
void    setBatchSize(int batchSize)    Sets a new batch size.
Methods inherited from class org.tribuo.Model
castModel, copy, copy, createDataCarrier, deserialize, deserializeFromFile, deserializeFromStream, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, serialize, serializeToFile, serializeToStream, setName, toString, validate
-
Field Details
-
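batchSize
protected int batchSize
-
outputName
protected final String outputName
-
featureConverter
protected final FeatureConverter featureConverter
-
outputConverter
protected final OutputConverter<T> outputConverter
-
modelGraph
protected transient org.tensorflow.Graph modelGraph
-
session
protected transient org.tensorflow.Session session
-
closed
protected transient boolean closed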
-
-
Constructor Details
-
TensorFlowModel
protected TensorFlowModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, org.tensorflow.proto.framework.GraphDef trainedGraphDef, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter)
Builds a TFModel. The session should be initialized in the subclass constructor.
- Parameters:
name - The model name.
provenance - The model provenance.
featureIDMap - The feature domain.
outputIDInfo - The output domain.
trainedGraphDef - The graph definition.
batchSize - The test time batch size.
outputName - The name of the output operation.
featureConverter - The feature converter.
outputConverter - The output converter.
-
-
Method Details
-
predict
public Prediction<T> predict(Example<T> example)
Description copied from class: Model
Uses the model to predict the output for a single example. predict does not mutate the example.
Throws IllegalArgumentException if the example has no features or no feature overlap with the model.
-
innerPredict
protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples)
Description copied from class: Model
Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
- Overrides:
innerPredict in class Model<T extends Output<T>>
- Parameters:
examples - The examples to predict.
- Returns:
The results of the predictions, in the same order as the examples.
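The ordering guarantee above can be illustrated with a minimal sketch. It assumes the test-time batch size is used to group examples into consecutive batches; the helper `BatchingSketch.batches` is hypothetical and is not part of Tribuo. It only shows how an Iterable can be chunked without disturbing the input order.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper, not part of Tribuo: groups items into fixed-size batches. */
final class BatchingSketch {
    private BatchingSketch() {}

    /** Splits the input into consecutive batches of at most batchSize items, keeping input order. */
    static <E> List<List<E>> batches(Iterable<E> items, int batchSize) {
        if (batchSize < 1) {
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
        List<List<E>> output = new ArrayList<>();
        List<E> current = new ArrayList<>(batchSize);
        for (E item : items) {
            current.add(item);
            if (current.size() == batchSize) {
                // The batch is full, so store it and start a new one.
                output.add(current);
                current = new ArrayList<>(batchSize);
            }
        }
        if (!current.isEmpty()) {
            // The final batch may be smaller than batchSize.
            output.add(current);
        }
        return output;
    }

    public static void main(String[] args) {
        System.out.println(batches(List.of(1, 2, 3, 4, 5), 2)); // prints [[1, 2], [3, 4], [5]]
    }
}
```

Concatenating the batches in order reproduces the original sequence, which is the property a caller relies on when matching predictions back to examples.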
-
getBatchSize
public int getBatchSize()
Gets the current testing batch size.
- Returns:
The batch size.
-
setBatchSize
public void setBatchSize(int batchSize)
Sets a new batch size.
Throws IllegalArgumentException if the batch size isn't positive.
- Parameters:
batchSize - The batch size to use.
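The documented contract (positive values accepted, anything else rejected with IllegalArgumentException) can be sketched as follows. `BatchSizeHolder` is a hypothetical stand-in written for illustration, not the Tribuo source, and the exception message is an assumption.

```java
/** Hypothetical stand-in for the batch size accessors; not the Tribuo source. */
final class BatchSizeHolder {
    private int batchSize;

    BatchSizeHolder(int batchSize) {
        this.batchSize = batchSize;
    }

    /** Gets the current batch size. */
    int getBatchSize() {
        return batchSize;
    }

    /** Sets a new batch size, rejecting zero and negative values. */
    void setBatchSize(int batchSize) {
        if (batchSize < 1) {
            // The message text is an assumption; only the exception type comes from the documentation.
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
        this.batchSize = batchSize;
    }

    public static void main(String[] args) {
        BatchSizeHolder holder = new BatchSizeHolder(16);
        holder.setBatchSize(64);
        System.out.println(holder.getBatchSize()); // prints 64
        try {
            holder.setBatchSize(0);
        } catch (IllegalArgumentException e) {
            // The rejected call leaves the previous value in place.
            System.out.println(holder.getBatchSize()); // prints 64
        }
    }
}
```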
-
getTopFeatures
public Map<String,List<Pair<String,Double>>> getTopFeatures(int n)
Deep learning models don't do feature rankings. Use an Explainer.
This method always returns the empty map.
- Specified by:
getTopFeatures in class Model<T extends Output<T>>
- Parameters:
n - the number of features to return.
- Returns:
The empty map.
-
getExcuse
public Optional<Excuse<T>> getExcuse(Example<T> e)
Deep learning models don't do excuses. Use an Explainer.
This method always returns Optional.empty().
-
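Because getTopFeatures and getExcuse always return empty results for this class, calling code should treat the empty case as the normal path rather than an error. The sketch below shows one way to do that. `NoExplanationSketch` is a hypothetical stand-in with simplified types (String in place of Tribuo's Example and Excuse, and a simplified map value type); it mirrors only the documented return values.

```java
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/** Hypothetical stand-in mirroring the documented return values; not the Tribuo source. */
final class NoExplanationSketch {
    private NoExplanationSketch() {}

    /** Mirrors getTopFeatures: always the empty map (value type simplified for illustration). */
    static Map<String, List<String>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /** Mirrors getExcuse: always Optional.empty() (String stands in for the real types). */
    static Optional<String> getExcuse(String example) {
        return Optional.empty();
    }

    /** Caller-side handling: fall back to a message when no excuse is available. */
    static String describe(String example) {
        return getExcuse(example).orElse("no excuse available, use an Explainer");
    }

    public static void main(String[] args) {
        System.out.println(getTopFeatures(5).isEmpty()); // prints true
        System.out.println(describe("some example"));    // prints no excuse available, use an Explainer
    }
}
```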
getOutputName
public String getOutputName()
Gets the name of the output operation.
- Returns:
The output operation name.
-
exportModel
public void exportModel(String path) throws IOException
Exports this model as a SavedModelBundle, writing to the supplied directory.
- Parameters:
path - The directory to export to.
- Throws:
IOException - If it failed to write to the directory.
-
close
public void close()
- Specified by:
close in interface AutoCloseable
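Since the class implements AutoCloseable and holds a TensorFlow Graph and Session, a try-with-resources block is a natural way to make sure close() runs. The sketch below uses a hypothetical stand-in, `ClosableModelSketch`, rather than a real TensorFlowModel. Throwing IllegalStateException when the model is used after close is an assumption made for illustration, not documented behaviour.

```java
/** Hypothetical stand-in for a model holding native resources; not the Tribuo source. */
final class ClosableModelSketch implements AutoCloseable {
    private boolean closed = false;

    /** Returns a dummy prediction, refusing to run once the model is closed (assumed behaviour). */
    String predict(String input) {
        if (closed) {
            throw new IllegalStateException("Can't use a closed model");
        }
        return "prediction for " + input;
    }

    boolean isClosed() {
        return closed;
    }

    @Override
    public void close() {
        // A real model would release its Session and Graph here.
        closed = true;
    }

    public static void main(String[] args) {
        ClosableModelSketch model = new ClosableModelSketch();
        try (ClosableModelSketch m = model) {
            System.out.println(m.predict("example")); // prints prediction for example
        }
        // close() has been called automatically at the end of the try block.
        System.out.println(model.isClosed()); // prints true
    }
}
```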
-