public final class TensorFlowSavedModelExternalModel<T extends Output<T>> extends ExternalModel<T,TensorMap,TensorMap> implements Closeable
The model's serialVersionUID is set to the major TensorFlow version number times 100.
N.B. TensorFlow support is experimental and may change without a major version bump.
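Here is a worked example of the serialVersionUID rule. TensorFlow 2.x gives 2 * 100 = 200, and TensorFlow 1.x gives 1 * 100 = 100. Java rejects deserialization with an `InvalidClassException` when the stream's serialVersionUID differs from the local class's. Tying the UID to the TensorFlow major version therefore presumably makes a model serialized against one major version fail fast under another. The helper `uidForTensorFlowVersion` below is hypothetical and not part of Tribuo. It only shows the arithmetic applied to a version string.

```java
public class SerialVersionSketch {

    /** Hypothetical helper: the major version of e.g. "2.4.1" is 2, so the UID is 200. */
    static long uidForTensorFlowVersion(String version) {
        int dot = version.indexOf('.');
        String major = (dot < 0) ? version : version.substring(0, dot);
        return Long.parseLong(major) * 100L;
    }

    public static void main(String[] args) {
        System.out.println(uidForTensorFlowVersion("2.4.1"));  // 200
        System.out.println(uidForTensorFlowVersion("1.15.0")); // 100
    }
}
```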
Fields inherited from class ExternalModel: DEFAULT_BATCH_SIZE, featureBackwardMapping, featureForwardMapping
Fields inherited from class Model: ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Modifier and Type | Method and Description |
---|---|
void | close() |
protected TensorMap | convertFeatures(SparseVector input) Converts from a SparseVector using the external model's indices into the ingestion format for the external model. |
protected TensorMap | convertFeaturesList(List<SparseVector> input) Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model. |
protected List<Prediction<T>> | convertOutput(TensorMap output, int[] numValidFeatures, List<Example<T>> examples) Converts a tensor into a list of predictions, one per example. |
protected Prediction<T> | convertOutput(TensorMap output, int numValidFeatures, Example<T> example) Converts a tensor into a prediction. |
protected Model<T> | copy(String newName, ModelProvenance newProvenance) Copies a model, replacing its provenance and name with the supplied values. |
static <T extends Output<T>> TensorFlowSavedModelExternalModel<T> | createTensorflowModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String bundleDirectory) Creates a TensorFlowSavedModelExternalModel by loading in a SavedModelBundle. |
protected TensorMap | externalPrediction(TensorMap input) Runs the session to make a prediction. |
Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> | getTopFeatures(int n) Gets the top n features associated with this model. |
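In the summary, `convertFeatures` and `convertFeaturesList` move a `SparseVector` indexed by the external model's ids into the model's ingestion format. Separately, `createTensorflowModel` takes a `featureMapping` from Tribuo feature names to TF integer ids. The sketch below shows one plausible conversion, which densifies named feature values into a `[batchSize][numFeatures]` float array. It makes several simplifying assumptions:

- Plain Java maps and arrays stand in for `SparseVector` and `TensorMap`.
- The name-to-id lookup and the densification are folded into a single step.
- Features absent from the mapping are dropped.
- The class and method names are hypothetical.

The real behaviour is whatever the supplied `FeatureConverter` implements.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class DenseFeatureSketch {

    /** Builds one dense row; features absent from the mapping are skipped (assumption). */
    static float[] toDenseRow(Map<String, Double> features,
                              Map<String, Integer> featureMapping,
                              int numFeatures) {
        float[] row = new float[numFeatures]; // positions never written stay 0.0f
        for (Map.Entry<String, Double> e : features.entrySet()) {
            Integer id = featureMapping.get(e.getKey());
            if (id != null) {
                row[id] = e.getValue().floatValue();
            }
        }
        return row;
    }

    /** Builds a [batchSize][numFeatures] array, one row per example. */
    static float[][] toDenseBatch(List<Map<String, Double>> batch,
                                  Map<String, Integer> featureMapping,
                                  int numFeatures) {
        float[][] out = new float[batch.size()][];
        for (int i = 0; i < batch.size(); i++) {
            out[i] = toDenseRow(batch.get(i), featureMapping, numFeatures);
        }
        return out;
    }

    public static void main(String[] args) {
        Map<String, Integer> mapping = Map.of("height", 0, "width", 1, "depth", 2);
        float[][] batch = toDenseBatch(
                List.of(Map.of("height", 1.5, "depth", 3.0),
                        Map.of("width", 2.0, "colour", 9.0)), // "colour" has no TF id
                mapping, 3);
        System.out.println(Arrays.deepToString(batch));
    }
}
```

Running `main` prints `[[1.5, 0.0, 3.0], [0.0, 2.0, 0.0]]`. The unmapped `colour` feature is silently ignored in this sketch.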
Methods inherited from class ExternalModel: createFeatureMap, createOutputInfo, getBatchSize, getExcuse, innerPredict, predict, setBatchSize
Methods inherited from class Model: copy, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, setName, toString, validate
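On the output side, `convertOutput` turns the output tensor back into `Prediction` objects, and `outputMapping` records which TF integer id corresponds to each Tribuo output. Consider a classifier whose output tensor holds one score per class. One plausible conversion is an argmax over each row. The sketch below is hypothetical and simplifies the real types:

- `String` labels stand in for `T`.
- `float` arrays stand in for the `TensorMap`.
- It returns only the winning label rather than a full `Prediction`.

In Tribuo the actual conversion is delegated to the supplied `OutputConverter`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class ArgmaxOutputSketch {

    /** Returns the label whose mapped column holds the largest score in this row. */
    static String argmaxLabel(float[] row, Map<String, Integer> outputMapping) {
        if (row.length != outputMapping.size()) {
            throw new IllegalArgumentException(
                    "Expected " + outputMapping.size() + " scores, got " + row.length);
        }
        String best = null;
        float bestScore = Float.NEGATIVE_INFINITY;
        // Ties go to whichever label is visited first; Map.of has no defined iteration order.
        for (Map.Entry<String, Integer> e : outputMapping.entrySet()) {
            float score = row[e.getValue()];
            if (best == null || score > bestScore) {
                best = e.getKey();
                bestScore = score;
            }
        }
        return best;
    }

    /** Batch form: one label per row of a [batchSize][numOutputs] score array. */
    static List<String> argmaxLabels(float[][] batch, Map<String, Integer> outputMapping) {
        List<String> labels = new ArrayList<>(batch.length);
        for (float[] row : batch) {
            labels.add(argmaxLabel(row, outputMapping));
        }
        return labels;
    }

    public static void main(String[] args) {
        Map<String, Integer> outputMapping = Map.of("cat", 0, "dog", 1, "bird", 2);
        System.out.println(argmaxLabel(new float[] {0.1f, 0.7f, 0.2f}, outputMapping)); // dog
        System.out.println(argmaxLabels(
                new float[][] {{0.9f, 0.05f, 0.05f}, {0.1f, 0.2f, 0.7f}},
                outputMapping)); // [cat, bird]
    }
}
```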
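protected TensorMap convertFeatures(SparseVector input)
Description copied from class: ExternalModel
Converts from a SparseVector using the external model's indices into the ingestion format for the external model.
Specified by: convertFeatures in class ExternalModel<T extends Output<T>,TensorMap,TensorMap>
Parameters:
input - The features using external indices.

protected TensorMap convertFeaturesList(List<SparseVector> input)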
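Description copied from class: ExternalModel
Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model.
Specified by: convertFeaturesList in class ExternalModel<T extends Output<T>,TensorMap,TensorMap>
Parameters:
input - The features using external indices.

protected TensorMap externalPrediction(TensorMap input)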
Runs the session to make a prediction. Closes the input tensor after the prediction has been made.
Specified by: externalPrediction in class ExternalModel<T extends Output<T>,TensorMap,TensorMap>
Parameters:
input - The input in the external model's format.

protected Prediction<T> convertOutput(TensorMap output, int numValidFeatures, Example<T> example)
Converts a tensor into a prediction.
Specified by: convertOutput in class ExternalModel<T extends Output<T>,TensorMap,TensorMap>
Parameters:
output - The output of the external model.
numValidFeatures - The number of valid features in the input.
example - The input example, used to construct the Prediction.
Returns:
A Prediction representing this tensor output.

protected List<Prediction<T>> convertOutput(TensorMap output, int[] numValidFeatures, List<Example<T>> examples)
Converts a tensor into a list of predictions, one per example.
Specified by: convertOutput in class ExternalModel<T extends Output<T>,TensorMap,TensorMap>
Parameters:
output - The output of the external model.
numValidFeatures - An array with the number of valid features in each example.
examples - The input examples, used to construct the Predictions.
Returns:
A list of Prediction objects representing this tensor output.

public Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
Description copied from class: Model
Gets the top n features associated with this model.
If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.
If the model cannot describe its top features, it returns Collections.emptyMap().
Specified by: getTopFeatures in class Model<T extends Output<T>>
Parameters:
n - The number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.

protected Model<T> copy(String newName, ModelProvenance newProvenance)
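Description copied from class: Model
Copies a model, replacing its provenance and name with the supplied values.
Used to provide the provenance removal functionality.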

public void close()
Specified by: close in interface Closeable
Specified by: close in interface AutoCloseable
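The class implements `Closeable`, and therefore `AutoCloseable`, so it can be managed with try-with-resources. That way `close()` runs even if prediction throws. The sketch below shows the pattern with a hypothetical stand-in class, `NativeBackedModel`, rather than the real model, because constructing the real one needs a SavedModel bundle on disk. With the real class, the resource would be the value returned by `createTensorflowModel`. The stand-in's refusal to predict after closing is an illustrative assumption, not documented behaviour of the Tribuo class.

```java
import java.io.Closeable;

public class CloseSketch {

    /** Hypothetical stand-in for a model that holds native resources such as a TF session. */
    static final class NativeBackedModel implements Closeable {
        private boolean closed = false;

        String predict(String input) {
            if (closed) {
                throw new IllegalStateException("Model has been closed");
            }
            return "prediction for " + input;
        }

        @Override
        public void close() { // narrower than Closeable.close(): declares no IOException
            closed = true;
        }

        boolean isClosed() {
            return closed;
        }
    }

    public static void main(String[] args) {
        NativeBackedModel model = new NativeBackedModel();
        // Java 9+ accepts an effectively final variable as the resource.
        try (model) {
            System.out.println(model.predict("x")); // prediction for x
        }
        System.out.println(model.isClosed()); // true
    }
}
```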

public static <T extends Output<T>> TensorFlowSavedModelExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String bundleDirectory)
Creates a TensorFlowSavedModelExternalModel by loading in a SavedModelBundle.
Throws IllegalArgumentException if the model bundle could not be loaded.
Type Parameters:
T - The type of the output.
Parameters:
factory - The output factory.
featureMapping - The feature mapping between Tribuo's names and the TF integer ids.
outputMapping - The output mapping between Tribuo's names and the TF integer ids.
outputName - The name of the output tensor.
featureConverter - The feature transformation function.
outputConverter - The output transformation function.
bundleDirectory - The path to load the saved model bundle from.
Returns:
The loaded model, wrapped as a Tribuo ExternalModel.

Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.