Class TensorFlowModel<T extends Output<T>>
java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.tensorflow.TensorFlowModel<T>
- Type Parameters:
T - The output type.
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, AutoCloseable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
- Direct Known Subclasses:
TensorFlowCheckpointModel, TensorFlowNativeModel
public abstract class TensorFlowModel<T extends Output<T>>
extends Model<T>
implements AutoCloseable
Base class for a TensorFlow model that operates on Examples.
The subclasses are package-private and concern themselves with how the model is stored on disk.
N.B. TensorFlow support is experimental and may change without a major version bump.
Field Summary
Fields
- protected int batchSize
- protected boolean closed
- protected final FeatureConverter featureConverter
- protected org.tensorflow.Graph modelGraph
- protected final OutputConverter<T> outputConverter
- protected final String outputName
- protected org.tensorflow.Session session
Fields inherited from class org.tribuo.Model:
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Fields inherited from interface org.tribuo.protos.ProtoSerializable:
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
Constructor Summary
Constructors
- protected TensorFlowModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, org.tensorflow.proto.GraphDef trainedGraphDef, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter)
  Builds a TFModel.
Method Summary
- void close()
- void exportModel(String path)
  Exports this model as a SavedModelBundle, writing to the supplied directory.
- int getBatchSize()
  Gets the current testing batch size.
- Optional<Excuse<T>> getExcuse(Example<T> example)
  Deep learning models don't do excuses.
- String getOutputName()
  Gets the name of the output operation.
- Map<String,List<Pair<String,Double>>> getTopFeatures(int n)
  Deep learning models don't do feature rankings.
- protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples)
  Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
- Prediction<T> predict(Example<T> example)
  Uses the model to predict the output for a single example.
- void setBatchSize(int batchSize)
  Sets a new batch size.
Methods inherited from class org.tribuo.Model:
castModel, copy, copy, createDataCarrier, deserialize, deserializeFromFile, deserializeFromStream, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, serialize, serializeToFile, serializeToStream, setName, toString, validate
Field Details
- batchSize
  protected int batchSize
- outputName
  protected final String outputName
- featureConverter
  protected final FeatureConverter featureConverter
- outputConverter
  protected final OutputConverter<T> outputConverter
- modelGraph
  protected transient org.tensorflow.Graph modelGraph
- session
  protected transient org.tensorflow.Session session
- closed
  protected transient boolean closed
Constructor Details
- TensorFlowModel
  protected TensorFlowModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, org.tensorflow.proto.GraphDef trainedGraphDef, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter)
  Builds a TFModel. The session should be initialized in the subclass constructor.
  - Parameters:
    name - The model name.
    provenance - The model provenance.
    featureIDMap - The feature domain.
    outputIDInfo - The output domain.
    trainedGraphDef - The graph definition.
    batchSize - The test time batch size.
    outputName - The name of the output operation.
    featureConverter - The feature converter.
    outputConverter - The output converter.
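The constructor note above leaves session creation to subclasses. One plausible reason, sketched below under stated assumptions, is Java's initialization order: while a superclass constructor runs, the subclass's own fields (for example, where its weights live on disk) are not yet assigned. The classes `SessionOwningModel` and `CheckpointBackedModel` are hypothetical, and a `String` stands in for `org.tensorflow.Session`; this is not Tribuo's code.

```java
// Hypothetical sketch of "initialize the session in the subclass constructor".
// A String stands in for org.tensorflow.Session; class names are illustrative.
abstract class SessionOwningModel {
    protected final String outputName;
    protected String session; // assigned by each subclass constructor

    protected SessionOwningModel(String outputName) {
        this.outputName = outputName;
        // No session here: subclass fields are still unassigned while the
        // superclass constructor runs, so storage-specific setup must wait.
    }

    boolean isReady() {
        return session != null;
    }
}

final class CheckpointBackedModel extends SessionOwningModel {
    private final String checkpointDirectory; // hypothetical storage detail

    CheckpointBackedModel(String outputName, String checkpointDirectory) {
        super(outputName);
        this.checkpointDirectory = checkpointDirectory;
        // Subclass state now exists, so the session can be built from it.
        this.session = "session restored from " + this.checkpointDirectory;
    }

    public static void main(String[] args) {
        CheckpointBackedModel model = new CheckpointBackedModel("output", "/tmp/ckpt");
        System.out.println(model.isReady()); // true
    }
}
```

This keeps the base class unaware of how weights are stored, which matches the class description's split between the base class and its storage-specific subclasses.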
Method Details
- predict
  public Prediction<T> predict(Example<T> example)
  Description copied from class: Model
  Uses the model to predict the output for a single example.
  predict does not mutate the example.
  - Throws:
    IllegalArgumentException - if the example has no features or no feature overlap with the model.
- innerPredict
  protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples)
  Description copied from class: Model
  Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
  - Overrides:
    innerPredict in class Model<T extends Output<T>>
  - Parameters:
    examples - The examples to predict.
  - Returns:
    The results of the predictions, in the same order as the examples.
- getBatchSize
  public int getBatchSize()
  Gets the current testing batch size.
  - Returns:
    The batch size.
- setBatchSize
  public void setBatchSize(int batchSize)
  Sets a new batch size.
  - Parameters:
    batchSize - The batch size to use.
  - Throws:
    IllegalArgumentException - if the batch size is not positive.
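The batch size and the ordering guarantee of innerPredict fit together as follows: examples are fed to the graph in chunks of at most batchSize, and the per-chunk outputs are concatenated so that prediction i still belongs to example i. The sketch below shows that idea in plain Java. It is a hypothetical stand-in, not Tribuo's implementation: `BatchedPredictor` and `predictAll` are illustrative names, and the `runBatch` function plays the role of one TensorFlow session run over a batch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

// Hypothetical stand-in for the batching behind innerPredict; not Tribuo code.
// runBatch plays the role of one TensorFlow session run over a whole batch.
final class BatchedPredictor<I, O> {
    private final Function<List<I>, List<O>> runBatch;
    private int batchSize;

    BatchedPredictor(int batchSize, Function<List<I>, List<O>> runBatch) {
        setBatchSize(batchSize);
        this.runBatch = runBatch;
    }

    int getBatchSize() {
        return batchSize;
    }

    // Same documented contract as setBatchSize: reject non-positive sizes.
    void setBatchSize(int batchSize) {
        if (batchSize < 1) {
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
        this.batchSize = batchSize;
    }

    // Feeds consecutive chunks of at most batchSize inputs to runBatch and
    // concatenates the outputs, so result i belongs to input i.
    List<O> predictAll(Iterable<I> inputs) {
        List<O> results = new ArrayList<>();
        List<I> batch = new ArrayList<>(batchSize);
        for (I input : inputs) {
            batch.add(input);
            if (batch.size() == batchSize) {
                results.addAll(runBatch.apply(batch));
                batch = new ArrayList<>(batchSize);
            }
        }
        if (!batch.isEmpty()) {
            results.addAll(runBatch.apply(batch)); // trailing partial batch
        }
        return results;
    }

    public static void main(String[] args) {
        BatchedPredictor<Integer, Integer> p = new BatchedPredictor<>(2, batch -> {
            List<Integer> out = new ArrayList<>();
            for (int x : batch) {
                out.add(x * 10);
            }
            return out;
        });
        System.out.println(p.predictAll(Arrays.asList(1, 2, 3, 4, 5))); // [10, 20, 30, 40, 50]
    }
}
```

With a batch size of 2, five inputs are processed as batches of sizes 2, 2 and 1. A larger batch size means fewer session runs at the cost of larger input tensors per run.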
- getTopFeatures
  public Map<String,List<Pair<String,Double>>> getTopFeatures(int n)
  Deep learning models don't do feature rankings. Use an Explainer.
  This method always returns the empty map.
  - Specified by:
    getTopFeatures in class Model<T extends Output<T>>
  - Parameters:
    n - The number of features to return.
  - Returns:
    The empty map.
- getExcuse
  public Optional<Excuse<T>> getExcuse(Example<T> example)
  Deep learning models don't do excuses.
- getOutputName
  public String getOutputName()
  Gets the name of the output operation.
  - Returns:
    The output operation name.
- exportModel
  public void exportModel(String path) throws IOException
  Exports this model as a SavedModelBundle, writing to the supplied directory.
  - Parameters:
    path - The directory to export to.
  - Throws:
    IOException - If it failed to write to the directory.
- close
  public void close()
  - Specified by:
    close in interface AutoCloseable
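Because the model holds a native TensorFlow graph and session and implements AutoCloseable, callers will usually scope it with try-with-resources. The sketch below illustrates a close-once lifecycle guarded by a flag like the closed field. Everything in it is a hypothetical stand-in: `FakeHandle` replaces `org.tensorflow.Session` and `org.tensorflow.Graph`, and throwing IllegalStateException on use after close is an assumption of this sketch, not documented behavior of TensorFlowModel.

```java
// Hypothetical sketch of a close-once lifecycle guarded by a "closed" flag.
// FakeHandle stands in for org.tensorflow.Session and org.tensorflow.Graph.
final class FakeHandle implements AutoCloseable {
    private int closeCount = 0;

    int closeCount() {
        return closeCount;
    }

    @Override
    public void close() {
        closeCount++;
    }
}

final class ClosableModelSketch implements AutoCloseable {
    private final FakeHandle graph = new FakeHandle();
    private final FakeHandle session = new FakeHandle();
    private boolean closed = false;

    FakeHandle graphHandle() {
        return graph;
    }

    FakeHandle sessionHandle() {
        return session;
    }

    // Assumption for this sketch: using a closed model is an error.
    String predict(String input) {
        if (closed) {
            throw new IllegalStateException("Model has been closed");
        }
        return "prediction for " + input;
    }

    // Idempotent: the native stand-ins are released only on the first call.
    @Override
    public void close() {
        if (!closed) {
            session.close(); // release the session before the graph it runs on
            graph.close();
            closed = true;
        }
    }

    public static void main(String[] args) {
        ClosableModelSketch model = new ClosableModelSketch();
        try (ClosableModelSketch m = model) {
            System.out.println(m.predict("x")); // prediction for x
        }
        model.close(); // second call is a no-op
        System.out.println(model.sessionHandle().closeCount()); // 1
    }
}
```

Making close() idempotent lets a model be closed both explicitly and by a try-with-resources block without releasing native resources twice.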