public final class TensorflowExternalModel<T extends Output<T>> extends ExternalModel<T,org.tensorflow.Tensor<?>,org.tensorflow.Tensor<?>> implements Closeable
The model's serialVersionUID is set to the major Tensorflow version number times 100.
N.B. Tensorflow support is experimental and may change without a major version bump.
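The versioning rule above can be illustrated with a small dependency-free sketch. `SerialVersionSketch` and `fromVersion` are hypothetical names, not part of Tribuo, and the version string passed in is only an example input. Under standard Java serialization, a stream whose serialVersionUID differs from that of the loading class is rejected, so a rule like this would tie serialized models to one major Tensorflow version.

```java
// Hypothetical helper, not part of Tribuo: applies the
// "major version number times 100" rule to a version string.
final class SerialVersionSketch {
    private SerialVersionSketch() {}

    static long fromVersion(String version) {
        // Keep only the text before the first '.', e.g. "1" from "1.15.0".
        int dot = version.indexOf('.');
        String major = dot < 0 ? version : version.substring(0, dot);
        return Long.parseLong(major) * 100L;
    }

    public static void main(String[] args) {
        // The version string here is an illustrative input only.
        System.out.println(fromVersion("1.15.0")); // prints 100
    }
}
```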
Fields inherited from class ExternalModel: DEFAULT_BATCH_SIZE, featureBackwardMapping, featureForwardMapping
Fields inherited from class Model: ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Modifier and Type | Method and Description |
---|---|
void | close() |
protected org.tensorflow.Tensor<?> | convertFeatures(SparseVector input): Converts from a SparseVector using the external model's indices into the ingestion format for the external model. |
protected org.tensorflow.Tensor<?> | convertFeaturesList(List<SparseVector> input): Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model. |
protected List<Prediction<T>> | convertOutput(org.tensorflow.Tensor<?> output, int[] numValidFeatures, List<Example<T>> examples): Converts a tensor into a prediction. |
protected Prediction<T> | convertOutput(org.tensorflow.Tensor<?> output, int numValidFeatures, Example<T> example): Converts a tensor into a prediction. |
protected Model<T> | copy(String newName, ModelProvenance newProvenance): Copies a model, replacing its provenance and name with the supplied values. |
static <T extends Output<T>> TensorflowExternalModel<T> | createTensorflowModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, String inputName, String outputName, ExampleTransformer<T> featureTransformer, OutputTransformer<T> outputTransformer, String filename): Creates a TensorflowExternalModel by loading in a frozen graph. |
protected org.tensorflow.Tensor<?> | externalPrediction(org.tensorflow.Tensor<?> input): Runs the session to make a prediction. |
Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> | getTopFeatures(int n): Gets the top n features associated with this model. |
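The feature conversion methods summarised above turn sparse vectors, indexed by the external model's feature ids, into the model's ingestion format. A minimal sketch of that idea follows, assuming the ingestion format is a dense row of floats. That layout is an assumption for illustration: the real methods return an `org.tensorflow.Tensor<?>` and delegate to the supplied `ExampleTransformer`. `SparseToDenseSketch` and its parallel-array representation of a sparse vector are hypothetical stand-ins, not Tribuo types.

```java
import java.util.List;

// Hypothetical stand-in: a sparse vector is given as parallel arrays of
// external feature indices and values. Not Tribuo's SparseVector.
final class SparseToDenseSketch {
    private SparseToDenseSketch() {}

    // Single example: scatter the sparse values into a zero-filled dense row.
    static float[] toDense(int[] indices, double[] values, int numFeatures) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values differ in length");
        }
        float[] dense = new float[numFeatures];
        for (int i = 0; i < indices.length; i++) {
            if (indices[i] < 0 || indices[i] >= numFeatures) {
                throw new IllegalArgumentException("index out of range: " + indices[i]);
            }
            dense[indices[i]] = (float) values[i];
        }
        return dense;
    }

    // Batch of examples: one dense row per sparse vector.
    static float[][] toDenseBatch(List<int[]> indices, List<double[]> values, int numFeatures) {
        if (indices.size() != values.size()) {
            throw new IllegalArgumentException("batch lists differ in length");
        }
        float[][] batch = new float[indices.size()][];
        for (int row = 0; row < batch.length; row++) {
            batch[row] = toDense(indices.get(row), values.get(row), numFeatures);
        }
        return batch;
    }
}
```

The batch variant mirrors the relationship between `convertFeatures` and `convertFeaturesList`: the same per-example conversion, stacked so the external model can be run once over the whole batch.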
Methods inherited from class ExternalModel: createFeatureMap, createOutputInfo, getBatchSize, getExcuse, innerPredict, predict, setBatchSize
Methods inherited from class Model: copy, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, setName, toString, validate
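`createTensorflowModel` takes an `outputMapping` between Tribuo's outputs and the TF integer ids, and `convertOutput` turns the model's output tensor back into a prediction. One way such a mapping could be used is sketched below, assuming the output is a row of per-class scores and the prediction is the highest-scoring class. Both are assumptions: the real conversion is done by the supplied `OutputTransformer`. `OutputMappingSketch` is hypothetical, and `String` stands in for the output type `T`.

```java
import java.util.Map;

// Hypothetical sketch, not Tribuo code: String stands in for the output type T.
final class OutputMappingSketch {
    private OutputMappingSketch() {}

    // Inverts a name -> id mapping into an id -> name lookup array.
    // Assumes the ids are unique and contiguous from 0.
    static String[] invert(Map<String, Integer> outputMapping) {
        String[] names = new String[outputMapping.size()];
        for (Map.Entry<String, Integer> e : outputMapping.entrySet()) {
            int id = e.getValue();
            if (id < 0 || id >= names.length || names[id] != null) {
                throw new IllegalArgumentException("ids must be unique and contiguous from 0");
            }
            names[id] = e.getKey();
        }
        return names;
    }

    // Returns the name whose id has the highest score.
    static String argmaxLabel(float[] scores, String[] idToName) {
        if (scores.length == 0 || scores.length != idToName.length) {
            throw new IllegalArgumentException("scores must be non-empty and match the mapping size");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return idToName[best];
    }
}
```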
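protected org.tensorflow.Tensor<?> convertFeatures(SparseVector input)
Description copied from class: ExternalModel
Converts from a SparseVector using the external model's indices into the ingestion format for the external model.
Specified by: convertFeatures in class ExternalModel<T extends Output<T>,org.tensorflow.Tensor<?>,org.tensorflow.Tensor<?>>
Parameters:
input - The features using external indices.

protected org.tensorflow.Tensor<?> convertFeaturesList(List<SparseVector> input)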
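Description copied from class: ExternalModel
Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model.
Specified by: convertFeaturesList in class ExternalModel<T extends Output<T>,org.tensorflow.Tensor<?>,org.tensorflow.Tensor<?>>
Parameters:
input - The features using external indices.

protected org.tensorflow.Tensor<?> externalPrediction(org.tensorflow.Tensor<?> input)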
Runs the session to make a prediction.
Specified by: externalPrediction in class ExternalModel<T extends Output<T>,org.tensorflow.Tensor<?>,org.tensorflow.Tensor<?>>
Parameters:
input - The input in the external model's format.

protected Prediction<T> convertOutput(org.tensorflow.Tensor<?> output, int numValidFeatures, Example<T> example)
Converts a tensor into a prediction.
Specified by: convertOutput in class ExternalModel<T extends Output<T>,org.tensorflow.Tensor<?>,org.tensorflow.Tensor<?>>
Parameters:
output - The output of the external model.
numValidFeatures - The number of valid features in the input.
example - The input example, used to construct the Prediction.
Returns: A Prediction representing this tensor output.

protected List<Prediction<T>> convertOutput(org.tensorflow.Tensor<?> output, int[] numValidFeatures, List<Example<T>> examples)
Converts a tensor into a prediction.
Specified by: convertOutput in class ExternalModel<T extends Output<T>,org.tensorflow.Tensor<?>,org.tensorflow.Tensor<?>>
Parameters:
output - The output of the external model.
numValidFeatures - An array with the number of valid features in each example.
examples - The input examples, used to construct the Predictions.
Returns: A list of Prediction objects representing this tensor output.

public Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
Description copied from class: Model
Gets the top n features associated with this model.
If the model does not produce per output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.
If the model cannot describe its top features then it returns Collections.emptyMap().
Specified by: getTopFeatures in class Model<T extends Output<T>>
Parameters:
n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.

protected Model<T> copy(String newName, ModelProvenance newProvenance)
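Description copied from class: Model
Copies a model, replacing its provenance and name with the supplied values.
Used to provide the provenance removal functionality.

public void close()
Specified by: close in interface Closeable
Specified by: close in interface AutoCloseable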
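Because the class implements Closeable, an instance can be managed with try-with-resources so that `close()` runs even if prediction throws. The sketch below shows the pattern with a hypothetical stand-in class, `CloseableModelSketch`, which is not part of Tribuo. Its `predict` method is a placeholder computation, and its behaviour after `close()` (throwing `IllegalStateException`) is an assumption about how such a wrapper might behave, not documented behaviour of TensorflowExternalModel.

```java
import java.io.Closeable;

// Hypothetical stand-in for a model that owns native resources.
final class CloseableModelSketch implements Closeable {
    private boolean closed = false;

    // Placeholder computation standing in for a call into the external model.
    float predict(float input) {
        if (closed) {
            throw new IllegalStateException("model is closed");
        }
        return input * 2.0f;
    }

    // Releases the (simulated) resources; narrowing the throws clause is allowed.
    @Override
    public void close() {
        closed = true;
    }

    boolean isClosed() {
        return closed;
    }

    // Uses the model inside try-with-resources and reports whether it was closed.
    static boolean demo() {
        CloseableModelSketch model = new CloseableModelSketch();
        try (CloseableModelSketch m = model) {
            m.predict(1.0f);
        }
        return model.isClosed();
    }
}
```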
public static <T extends Output<T>> TensorflowExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, String inputName, String outputName, ExampleTransformer<T> featureTransformer, OutputTransformer<T> outputTransformer, String filename)
Creates a TensorflowExternalModel by loading in a frozen graph.
Type Parameters:
T - The type of the output.
Parameters:
factory - The output factory.
featureMapping - The feature mapping between Tribuo's names and the TF integer ids.
outputMapping - The output mapping between Tribuo's names and the TF integer ids.
inputName - The name of the input placeholder.
outputName - The name of the output tensor.
featureTransformer - The feature transformation function.
outputTransformer - The output transformation function.
filename - The filename to load the graph from.

Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.