public final class TensorFlowFrozenExternalModel<T extends Output<T>> extends ExternalModel<T,TensorMap,org.tensorflow.Tensor> implements Closeable
The model's serialVersionUID is set to the major TensorFlow version number times 100.
N.B. TensorFlow support is experimental and may change without a major version bump.
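The serialVersionUID rule above comes down to a one-line computation. The sketch below uses a hypothetical helper (`TfSerialVersion`, not part of Tribuo) and assumes the version string looks like "major.minor.patch". The rule has a consequence under standard Java serialization: a model serialized against one TensorFlow major version carries a different serialVersionUID from a class built against another, so deserialization fails with an InvalidClassException. This is presumably the intent, since it avoids silently loading an incompatible graph.

```java
// Hypothetical helper, not a Tribuo API: mirrors the documented rule
// serialVersionUID = (major TensorFlow version) * 100.
final class TfSerialVersion {
    private TfSerialVersion() {}

    // Assumes a version string such as "2.4.1" or "1.15.0"; only the part before the first '.' is used.
    static long serialVersionUIDFor(String tensorflowVersion) {
        String trimmed = tensorflowVersion.trim();
        int dot = trimmed.indexOf('.');
        String major = (dot < 0) ? trimmed : trimmed.substring(0, dot);
        return Long.parseLong(major) * 100L;
    }
}
```

For example, `serialVersionUIDFor("1.15.0")` yields 100L and `serialVersionUIDFor("2.4.1")` yields 200L.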
Fields inherited from class ExternalModel: DEFAULT_BATCH_SIZE, featureBackwardMapping, featureForwardMapping
Fields inherited from class Model: ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Modifier and Type | Method and Description |
---|---|
void | close() |
protected TensorMap | convertFeatures(SparseVector input): Converts from a SparseVector using the external model's indices into the ingestion format for the external model. |
protected TensorMap | convertFeaturesList(List<SparseVector> input): Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model. |
protected List<Prediction<T>> | convertOutput(org.tensorflow.Tensor output, int[] numValidFeatures, List<Example<T>> examples): Converts a tensor into a list of predictions. |
protected Prediction<T> | convertOutput(org.tensorflow.Tensor output, int numValidFeatures, Example<T> example): Converts a tensor into a prediction. |
protected Model<T> | copy(String newName, ModelProvenance newProvenance): Copies a model, replacing its provenance and name with the supplied values. |
static <T extends Output<T>> TensorFlowFrozenExternalModel<T> | createTensorflowModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, String inputName, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String filename): Creates a TensorFlowFrozenExternalModel by loading in a frozen graph. |
protected org.tensorflow.Tensor | externalPrediction(TensorMap input): Runs the session to make a prediction. |
Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> | getTopFeatures(int n): Gets the top n features associated with this model. |
Methods inherited from class ExternalModel: createFeatureMap, createOutputInfo, getBatchSize, getExcuse, innerPredict, predict, setBatchSize
Methods inherited from class Model: copy, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, setName, toString, validate
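The summary methods split prediction into the three steps ExternalModel chains together. First, convertFeatures or convertFeaturesList turns Tribuo features into the model's input format. Next, externalPrediction runs the TensorFlow session. Finally, convertOutput turns the output tensor back into predictions.

The dependency-free sketch below imitates that flow under several assumptions:

- Features arrive as a name-to-value map rather than a SparseVector.
- The "session" is a plain `java.util.function.Function` over `float[][]` batches.
- Outputs are String labels chosen by argmax.
- The ids in both mappings are contiguous from 0.

All class and method names here are hypothetical. In the real class, the transformations come from the FeatureConverter and OutputConverter passed to createTensorflowModel.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical sketch of the ExternalModel flow; none of these names are Tribuo or TensorFlow APIs.
final class ExternalFlowSketch {
    // Feature name -> input column id (cf. the featureMapping parameter).
    private final Map<String, Integer> featureMapping;
    // Output column id -> label, i.e. the inverse of outputMapping.
    private final String[] idToLabel;
    // Stand-in for the TensorFlow session: dense input batch -> score batch.
    private final Function<float[][], float[][]> session;

    ExternalFlowSketch(Map<String, Integer> featureMapping,
                       Map<String, Integer> outputMapping,
                       Function<float[][], float[][]> session) {
        this.featureMapping = featureMapping;
        // Assumes output ids are exactly 0..outputMapping.size()-1.
        this.idToLabel = new String[outputMapping.size()];
        for (Map.Entry<String, Integer> e : outputMapping.entrySet()) {
            idToLabel[e.getValue()] = e.getKey();
        }
        this.session = session;
    }

    // Analogue of convertFeatures: one example -> a batch of size 1.
    float[][] convertFeatures(Map<String, Double> features) {
        return convertFeaturesList(List.of(features));
    }

    // Analogue of convertFeaturesList: n examples -> an n x numFeatures dense batch.
    // Assumes feature ids are exactly 0..featureMapping.size()-1; unknown names are skipped.
    float[][] convertFeaturesList(List<Map<String, Double>> examples) {
        float[][] batch = new float[examples.size()][featureMapping.size()];
        for (int i = 0; i < examples.size(); i++) {
            for (Map.Entry<String, Double> f : examples.get(i).entrySet()) {
                Integer id = featureMapping.get(f.getKey());
                if (id != null) {
                    batch[i][id] = f.getValue().floatValue();
                }
            }
        }
        return batch;
    }

    // Analogue of externalPrediction: run the stand-in session on a converted batch.
    float[][] externalPrediction(float[][] input) {
        return session.apply(input);
    }

    // Analogue of the batch convertOutput: argmax of each score row, mapped back to its label.
    List<String> convertOutput(float[][] scores) {
        List<String> labels = new ArrayList<>(scores.length);
        for (float[] row : scores) {
            int best = 0;
            for (int j = 1; j < row.length; j++) {
                if (row[j] > row[best]) {
                    best = j;
                }
            }
            labels.add(idToLabel[best]);
        }
        return labels;
    }

    // The three steps chained, as ExternalModel does for a batch of examples.
    List<String> predict(List<Map<String, Double>> examples) {
        return convertOutput(externalPrediction(convertFeaturesList(examples)));
    }
}
```

Take an identity session (`batch -> batch`) with two features mapped to two outputs. The example {a=2.0} is predicted as the label with id 0, and {b=5.0} as the label with id 1, which makes the id bookkeeping easy to check.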
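
protected TensorMap convertFeatures(SparseVector input)
Description copied from class: ExternalModel
Converts from a SparseVector using the external model's indices into the ingestion format for the external model.
Specified by:
convertFeatures in class ExternalModel<T extends Output<T>,TensorMap,org.tensorflow.Tensor>
Parameters:
input - The features using external indices.

protected TensorMap convertFeaturesList(List<SparseVector> input)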
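Description copied from class: ExternalModel
Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model.
Specified by:
convertFeaturesList in class ExternalModel<T extends Output<T>,TensorMap,org.tensorflow.Tensor>
Parameters:
input - The features using external indices.

protected org.tensorflow.Tensor externalPrediction(TensorMap input)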
Runs the session to make a prediction. Closes the input tensor after the prediction has been made.
Specified by:
externalPrediction in class ExternalModel<T extends Output<T>,TensorMap,org.tensorflow.Tensor>
Parameters:
input - The input in the external model's format.

protected Prediction<T> convertOutput(org.tensorflow.Tensor output, int numValidFeatures, Example<T> example)
Converts a tensor into a prediction.
Specified by:
convertOutput in class ExternalModel<T extends Output<T>,TensorMap,org.tensorflow.Tensor>
Parameters:
output - The output of the external model.
numValidFeatures - The number of valid features in the input.
example - The input example, used to construct the Prediction.
Returns:
A Prediction representing this tensor output.

protected List<Prediction<T>> convertOutput(org.tensorflow.Tensor output, int[] numValidFeatures, List<Example<T>> examples)
Converts a tensor into a list of predictions.
Specified by:
convertOutput in class ExternalModel<T extends Output<T>,TensorMap,org.tensorflow.Tensor>
Parameters:
output - The output of the external model.
numValidFeatures - An array with the number of valid features in each example.
examples - The input examples, used to construct the Predictions.
Returns:
A list of Predictions representing this tensor output.

public Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
Description copied from class: Model
Gets the top n features associated with this model.
If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.
If the model cannot describe its top features, it returns Collections.emptyMap().
Specified by:
getTopFeatures in class Model<T extends Output<T>>
Parameters:
n - The number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.

protected Model<T> copy(String newName, ModelProvenance newProvenance)
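Description copied from class: Model
Copies a model, replacing its provenance and name with the supplied values. Used to provide the provenance removal functionality.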

public void close()
Specified by:
close in interface Closeable
Specified by:
close in interface AutoCloseable
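Because the class implements Closeable, the natural way to use it is inside try-with-resources, so that close() runs even if a prediction throws. This page does not say what close() releases. TensorFlow Java's Graph and Session hold native memory and are themselves AutoCloseable, so presumably close() releases them.

The sketch below uses a hypothetical stand-in class, `NativeBackedModel`, instead of the real model, so it runs without TensorFlow. Its close() is idempotent, as the Closeable contract asks: closing an already-closed resource has no effect.

```java
import java.io.Closeable;

// Hypothetical stand-in for a model that owns native resources; not the Tribuo class.
final class NativeBackedModel implements Closeable {
    private boolean closed = false;

    String predict(String input) {
        if (closed) {
            throw new IllegalStateException("model already closed");
        }
        return "prediction for " + input;
    }

    // Narrows Closeable.close() to throw no checked exception; a second call is a no-op.
    @Override
    public void close() {
        if (!closed) {
            closed = true;
            // A real implementation would release its native handles here.
        }
    }

    boolean isClosed() {
        return closed;
    }
}

final class CloseDemo {
    private CloseDemo() {}

    // try-with-resources guarantees close() runs, even if predict throws.
    static NativeBackedModel predictOnce(String input, StringBuilder out) {
        NativeBackedModel model = new NativeBackedModel();
        try (NativeBackedModel m = model) {
            out.append(m.predict(input));
        }
        return model;
    }
}
```

Because `NativeBackedModel.close()` declares no checked exception, the try-with-resources block needs no `catch (IOException e)` clause.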

public static <T extends Output<T>> TensorFlowFrozenExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, String inputName, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String filename)
Creates a TensorFlowFrozenExternalModel by loading in a frozen graph.
Type Parameters:
T - The type of the output.
Parameters:
factory - The output factory.
featureMapping - The feature mapping between Tribuo's names and the TF integer ids.
outputMapping - The output mapping between Tribuo's names and the TF integer ids.
inputName - The name of the input placeholder.
outputName - The name of the output tensor.
featureConverter - The feature transformation function.
outputConverter - The output transformation function.
filename - The filename to load the graph from.

Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.