Class TensorFlowModel<T extends Output<T>>
java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.tensorflow.TensorFlowModel<T>
- Type Parameters:
- T - The output type.
- All Implemented Interfaces:
- com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, AutoCloseable
- Direct Known Subclasses:
- TensorFlowCheckpointModel, TensorFlowNativeModel
public abstract class TensorFlowModel<T extends Output<T>>
extends Model<T>
implements AutoCloseable
Base class for a TensorFlow model that operates on Examples. The subclasses are package-private and concern themselves with how the model is stored on disk.
N.B. TensorFlow support is experimental and may change without a major version bump.
- See Also:
- 
Field Summary
Fields
- protected int batchSize
- protected boolean closed
- protected final FeatureConverter featureConverter
- protected org.tensorflow.Graph modelGraph
- protected final OutputConverter<T> outputConverter
- protected final String outputName
- protected org.tensorflow.Session session
Fields inherited from class org.tribuo.Model: ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Constructor Summary
Constructors
- protected TensorFlowModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, org.tensorflow.proto.framework.GraphDef trainedGraphDef, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter)
  Builds a TFModel.
Method Summary
- void close()
- void exportModel(String path)
  Exports this model as a SavedModelBundle, writing to the supplied directory.
- int getBatchSize()
  Gets the current testing batch size.
- Optional<Excuse<T>> getExcuse(Example<T> example)
  Deep learning models don't do excuses.
- String getOutputName()
  Gets the name of the output operation.
- Map<String,List<Pair<String,Double>>> getTopFeatures(int n)
  Deep learning models don't do feature rankings.
- protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples)
  Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
- Prediction<T> predict(Example<T> example)
  Uses the model to predict the output for a single example.
- void setBatchSize(int batchSize)
  Sets a new batch size.
Methods inherited from class org.tribuo.Model: castModel, copy, copy, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, setName, toString, validate
Field Details
batchSize
protected int batchSize

outputName
protected final String outputName

featureConverter
protected final FeatureConverter featureConverter

outputConverter
protected final OutputConverter<T> outputConverter

modelGraph
protected transient org.tensorflow.Graph modelGraph

session
protected transient org.tensorflow.Session session

closed
protected transient boolean closed
 
Constructor Details
TensorFlowModel
protected TensorFlowModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, org.tensorflow.proto.framework.GraphDef trainedGraphDef, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter)
Builds a TFModel. The session should be initialized in the subclass constructor.
- Parameters:
- name - The model name.
- provenance - The model provenance.
- featureIDMap - The feature domain.
- outputIDInfo - The output domain.
- trainedGraphDef - The graph definition.
- batchSize - The test time batch size.
- outputName - The name of the output operation.
- featureConverter - The feature converter.
- outputConverter - The output converter.
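The constructor contract above (the base class stores the shared state, the subclass constructor opens the session) can be sketched in plain Java. This is a hedged illustration only: `StubSession`, `BaseModel` and `InMemoryModel` are hypothetical stand-ins, not Tribuo or TensorFlow classes, and the real subclasses create an `org.tensorflow.Session` from the trained graph rather than a stub.

```java
import java.util.Objects;

class ConstructorContractSketch {
    /** Hypothetical stand-in for org.tensorflow.Session; not the real TensorFlow class. */
    static final class StubSession implements AutoCloseable {
        private boolean open = true;

        boolean isOpen() { return open; }

        @Override
        public void close() { open = false; }
    }

    /** Hypothetical base class: stores shared state but leaves the session to subclasses. */
    abstract static class BaseModel implements AutoCloseable {
        protected final String outputName;
        protected int batchSize;
        protected transient StubSession session; // still null when this constructor returns

        protected BaseModel(String outputName, int batchSize) {
            this.outputName = Objects.requireNonNull(outputName, "outputName");
            this.batchSize = batchSize;
        }

        @Override
        public void close() {
            if (session != null) {
                session.close();
            }
        }
    }

    /** Hypothetical subclass: its constructor is where the session gets created. */
    static final class InMemoryModel extends BaseModel {
        InMemoryModel(String outputName, int batchSize) {
            super(outputName, batchSize);
            this.session = new StubSession(); // initialized after super(), per the documented contract
        }
    }

    public static void main(String[] args) {
        try (InMemoryModel model = new InMemoryModel("output", 32)) {
            System.out.println(model.session.isOpen()); // prints: true
        }
    }
}
```

Keeping session creation out of the base constructor lets each storage format decide how the session is built, which matches the note that the subclasses concern themselves with how the model is stored on disk.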

Method Details
predict
public Prediction<T> predict(Example<T> example)
Description copied from class: Model
Uses the model to predict the output for a single example.
predict does not mutate the example.
Throws IllegalArgumentException if the example has no features or no feature overlap with the model.
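The feature-overlap rule for `predict` can be illustrated with a small pre-check. This is a sketch under assumptions: `checkOverlap` is a hypothetical helper that models an example as a list of feature names and the feature domain as a set of names; it is not the code Tribuo runs internally.

```java
import java.util.List;
import java.util.Set;

class FeatureOverlapSketch {
    /**
     * Hypothetical pre-check: rejects an example that has no features, or whose
     * feature names are all absent from the model's feature domain.
     */
    static void checkOverlap(List<String> exampleFeatures, Set<String> featureDomain) {
        if (exampleFeatures.isEmpty()) {
            throw new IllegalArgumentException("Example has no features");
        }
        // noneMatch is true only when not a single feature name is known to the model
        if (exampleFeatures.stream().noneMatch(featureDomain::contains)) {
            throw new IllegalArgumentException("Example has no feature overlap with the model");
        }
    }

    public static void main(String[] args) {
        Set<String> domain = Set.of("age", "height");
        checkOverlap(List.of("age", "shoe-size"), domain); // accepted: "age" is in the domain
        try {
            checkOverlap(List.of("shoe-size"), domain);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // prints: Example has no feature overlap with the model
        }
    }
}
```

A partial overlap is enough to pass the check; features outside the domain are simply ignored in this sketch.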

innerPredict
protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples)
Description copied from class: Model
Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
- Overrides:
- innerPredict in class Model<T extends Output<T>>
- Parameters:
- examples - The examples to predict.
- Returns:
- The results of the predictions, in the same order as the examples.
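One plausible way for a model to honour both the test-time batch size and the ordering guarantee of `innerPredict` is to feed examples to the backend in fixed-size chunks and append each chunk's results in turn. The sketch below assumes that design; `predictInBatches` and its `predictBatch` callback are hypothetical and stand in for whatever runs the TensorFlow session on a batch.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

class OrderedBatchingSketch {
    /**
     * Hypothetical helper: passes examples to predictBatch in chunks of at most batchSize
     * and concatenates the results, so the output order matches the input order.
     */
    static <E, P> List<P> predictInBatches(Iterable<E> examples, int batchSize,
                                           Function<List<E>, List<P>> predictBatch) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, found " + batchSize);
        }
        List<P> results = new ArrayList<>();
        List<E> batch = new ArrayList<>(batchSize);
        for (E example : examples) {
            batch.add(example);
            if (batch.size() == batchSize) {
                results.addAll(predictBatch.apply(batch));
                batch = new ArrayList<>(batchSize); // fresh list in case the callee kept a reference
            }
        }
        if (!batch.isEmpty()) {
            results.addAll(predictBatch.apply(batch)); // final, possibly short, batch
        }
        return results;
    }

    public static void main(String[] args) {
        List<String> inputs = List.of("a", "b", "c", "d", "e");
        List<Integer> sizes = new ArrayList<>();
        List<String> out = predictInBatches(inputs, 2, batch -> {
            sizes.add(batch.size());
            List<String> preds = new ArrayList<>();
            for (String s : batch) {
                preds.add(s.toUpperCase()); // stand-in for a real per-example prediction
            }
            return preds;
        });
        System.out.println(out);   // prints: [A, B, C, D, E]
        System.out.println(sizes); // prints: [2, 2, 1]
    }
}
```

Because each chunk's results are appended before the next chunk is processed, the ordering guarantee holds for any batch size, including a final batch that is smaller than the rest.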
 
getBatchSize
public int getBatchSize()
Gets the current testing batch size.
- Returns:
- The batch size.
 
setBatchSize
public void setBatchSize(int batchSize)
Sets a new batch size. Throws IllegalArgumentException if the batch size isn't positive.
- Parameters:
- batchSize - The batch size to use.
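A minimal sketch of the documented `setBatchSize` contract, assuming the batch size is a plain field guarded by a positivity check. `BatchSizeSketch` is a hypothetical class, not part of Tribuo; validating in its constructor as well is a choice made for this sketch only.

```java
final class BatchSizeSketch {
    private int batchSize;

    BatchSizeSketch(int batchSize) {
        setBatchSize(batchSize); // sketch-only choice: reuse the same validation at construction
    }

    int getBatchSize() {
        return batchSize;
    }

    /** Rejects zero and negative sizes; the stored value only changes on success. */
    void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
        this.batchSize = batchSize;
    }

    public static void main(String[] args) {
        BatchSizeSketch model = new BatchSizeSketch(32);
        model.setBatchSize(128);
        try {
            model.setBatchSize(0);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // prints: Batch size must be positive, found 0
        }
        System.out.println(model.getBatchSize()); // prints: 128
    }
}
```

Checking before assigning means a rejected call leaves the previous batch size in force.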
 
getTopFeatures
public Map<String,List<Pair<String,Double>>> getTopFeatures(int n)
Deep learning models don't do feature rankings. Use an Explainer. This method always returns the empty map.
- Specified by:
- getTopFeatures in class Model<T extends Output<T>>
- Parameters:
- n - The number of features to return.
- Returns:
- The empty map.
 
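getExcuse
public Optional<Excuse<T>> getExcuse(Example<T> example)
Deep learning models don't do excuses.

getOutputName
public String getOutputName()
Gets the name of the output operation.
- Returns:
- The output operation name.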
 
exportModel
public void exportModel(String path) throws IOException
Exports this model as a SavedModelBundle, writing to the supplied directory.
- Parameters:
- path - The directory to export to.
- Throws:
- IOException - If writing to the directory fails.
 
close
public void close()
- Specified by:
- close in interface AutoCloseable
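Because the class implements `AutoCloseable` and keeps a transient `closed` flag alongside its graph and session, a natural usage pattern is try-with-resources, with `close()` safe to call more than once. The sketch assumes the flag exists to make `close()` idempotent; `NativeHandle` and `ModelSketch` are hypothetical stand-ins for the TensorFlow session and graph, and the counter simply records how many handles were released.

```java
import java.util.concurrent.atomic.AtomicInteger;

class IdempotentCloseSketch {
    /** Hypothetical stand-in for a native handle such as a TensorFlow Session or Graph. */
    static final class NativeHandle implements AutoCloseable {
        private final AtomicInteger releases;

        NativeHandle(AtomicInteger releases) { this.releases = releases; }

        @Override
        public void close() { releases.incrementAndGet(); }
    }

    /** Hypothetical model: the closed flag guards against releasing the handles twice. */
    static final class ModelSketch implements AutoCloseable {
        private final NativeHandle graph;
        private final NativeHandle session;
        private boolean closed = false;

        ModelSketch(AtomicInteger releases) {
            this.graph = new NativeHandle(releases);
            this.session = new NativeHandle(releases);
        }

        boolean isClosed() { return closed; }

        @Override
        public void close() {
            if (!closed) {
                session.close(); // close the session, then the graph it was created from
                graph.close();
                closed = true;
            }
        }
    }

    public static void main(String[] args) {
        AtomicInteger releases = new AtomicInteger();
        try (ModelSketch model = new ModelSketch(releases)) {
            model.close(); // explicit early close; the automatic close at the end of the block is a no-op
        }
        System.out.println(releases.get()); // prints: 2
    }
}
```

Try-with-resources guarantees the native resources are released even if prediction throws, and the flag makes an extra explicit `close()` harmless.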