public final class ONNXExternalModel<T extends Output<T>> extends ExternalModel<T,ai.onnxruntime.OnnxTensor,List<ai.onnxruntime.OnnxValue>> implements AutoCloseable
N.B. ONNX support is experimental, and may change without a major version bump.
Fields inherited from class ExternalModel: DEFAULT_BATCH_SIZE, featureBackwardMapping, featureForwardMapping
Fields inherited from class Model: ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Modifier and Type | Method and Description |
---|---|
void | close() |
protected ai.onnxruntime.OnnxTensor | convertFeatures(SparseVector input) - Converts from a SparseVector using the external model's indices into the ingestion format for the external model. |
protected ai.onnxruntime.OnnxTensor | convertFeaturesList(List<SparseVector> input) - Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model. |
protected List<Prediction<T>> | convertOutput(List<ai.onnxruntime.OnnxValue> output, int[] numValidFeatures, List<Example<T>> examples) - Converts a tensor into a prediction. |
protected Prediction<T> | convertOutput(List<ai.onnxruntime.OnnxValue> output, int numValidFeatures, Example<T> example) - Converts a tensor into a prediction. |
protected Model<T> | copy(String newName, ModelProvenance newProvenance) - Copies a model, replacing its provenance and name with the supplied values. |
static <T extends Output<T>> ONNXExternalModel<T> | createOnnxModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, ai.onnxruntime.OrtSession.SessionOptions opts, Path path, String inputName) - Creates an ONNXExternalModel by loading the model from disk. |
static <T extends Output<T>> ONNXExternalModel<T> | createOnnxModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, ai.onnxruntime.OrtSession.SessionOptions opts, String filename, String inputName) - Creates an ONNXExternalModel by loading the model from disk. |
protected List<ai.onnxruntime.OnnxValue> | externalPrediction(ai.onnxruntime.OnnxTensor input) - Runs the session to make a prediction. |
Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> | getTopFeatures(int n) - Gets the top n features associated with this model. |
void | rebuild(ai.onnxruntime.OrtSession.SessionOptions newOptions) - Closes the session and rebuilds it using the supplied options. |
Methods inherited from class ExternalModel: createFeatureMap, createOutputInfo, getBatchSize, getExcuse, innerPredict, predict, setBatchSize
Methods inherited from class Model: copy, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, predict, predict, setName, toString, validate
public void rebuild(ai.onnxruntime.OrtSession.SessionOptions newOptions) throws ai.onnxruntime.OrtException
Closes the session and rebuilds it using the supplied options. Use this to select a different backend or to change the number of inference threads, etc.
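The close-then-recreate ordering described above can be sketched without the ONNX Runtime dependency. In this sketch `SessionConfig`, `FakeSession` and `RebuildableModel` are hypothetical stand-ins (for `OrtSession.SessionOptions`, `OrtSession` and the model wrapper respectively); with the real API, a thread count would typically be set through `SessionOptions.setIntraOpNumThreads(int)`.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for OrtSession.SessionOptions: only carries a thread count. */
final class SessionConfig {
    final int intraOpThreads;
    SessionConfig(int intraOpThreads) { this.intraOpThreads = intraOpThreads; }
}

/** Hypothetical stand-in for OrtSession: records whether it has been closed. */
final class FakeSession implements AutoCloseable {
    final SessionConfig config;
    boolean closed = false;
    FakeSession(SessionConfig config) { this.config = config; }
    @Override public void close() { closed = true; }
}

/** Sketch of a model wrapper that can rebuild its session with new options. */
final class RebuildableModel implements AutoCloseable {
    private FakeSession session;
    /** Every session created so far, so a caller can check the old one was closed. */
    final List<FakeSession> history = new ArrayList<>();

    RebuildableModel(SessionConfig config) {
        session = new FakeSession(config);
        history.add(session);
    }

    /** Closes the current session, then opens a new one with the supplied options. */
    void rebuild(SessionConfig newConfig) {
        session.close();
        session = new FakeSession(newConfig);
        history.add(session);
    }

    int currentThreads() { return session.config.intraOpThreads; }

    @Override public void close() { session.close(); }
}
```

Closing first mirrors the description. If creating the new session can fail (as a native session constructor can), the wrapper is left holding a closed session, so a more defensive variant would build the replacement before closing the original.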
Parameters:
newOptions - The new session options.
Throws:
ai.onnxruntime.OrtException - If the model failed to rebuild the session with the supplied options.

protected ai.onnxruntime.OnnxTensor convertFeatures(SparseVector input)
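Description copied from class: ExternalModel
Converts from a SparseVector using the external model's indices into the ingestion format for the external model.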
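Assuming the exported graph takes a dense float matrix of shape [batch, numFeatures] (the actual input layout depends on the ONNX model), converting one sparse vector amounts to scattering its values into a zero-filled row at the external indices. This sketch represents the SparseVector as parallel `indices`/`values` arrays; `SparseToDense` and `toDenseRow` are hypothetical names.

```java
/** Hypothetical helper: scatters a sparse vector into a dense [1][numFeatures] array. */
final class SparseToDense {
    private SparseToDense() {}

    /**
     * @param indices     feature indices in the external model's numbering
     * @param values      feature values, parallel to indices
     * @param numFeatures width of the model's input
     * @return a 1 x numFeatures array, zero everywhere except at the given indices
     */
    static float[][] toDenseRow(int[] indices, double[] values, int numFeatures) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values must have the same length");
        }
        float[][] dense = new float[1][numFeatures];
        for (int i = 0; i < indices.length; i++) {
            int idx = indices[i];
            if (idx < 0 || idx >= numFeatures) {
                throw new IllegalArgumentException("index out of range: " + idx);
            }
            // ONNX models commonly expect float32 input, so narrow the double value
            dense[0][idx] = (float) values[i];
        }
        return dense;
    }
}
```

A two-dimensional primitive array like this is one of the inputs ONNX Runtime's `OnnxTensor.createTensor(OrtEnvironment, Object)` accepts, which is where the sketch would hand off to the native library.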
Specified by:
convertFeatures in class ExternalModel<T extends Output<T>,ai.onnxruntime.OnnxTensor,List<ai.onnxruntime.OnnxValue>>
Parameters:
input - The features using external indices.

protected ai.onnxruntime.OnnxTensor convertFeaturesList(List<SparseVector> input)
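Description copied from class: ExternalModel
Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model.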
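Batching applies the same scatter once per example. The sketch below writes the rows into one row-major `float[]` plus a shape array `{batchSize, numFeatures}`; the [batch, numFeatures] input shape is again an assumption about the exported graph, and `SparseBatch` and `flatten` are hypothetical names.

```java
import java.util.List;

/** Hypothetical helper: packs sparse rows into a row-major dense buffer. */
final class SparseBatch {
    final float[] data;  // element (r, c) lives at r * numFeatures + c
    final long[] shape;  // {batchSize, numFeatures}

    private SparseBatch(float[] data, long[] shape) { this.data = data; this.shape = shape; }

    /** Row r is described by the parallel arrays rowIndices.get(r) and rowValues.get(r). */
    static SparseBatch flatten(List<int[]> rowIndices, List<double[]> rowValues, int numFeatures) {
        if (rowIndices.size() != rowValues.size()) {
            throw new IllegalArgumentException("indices and values lists differ in size");
        }
        int batchSize = rowIndices.size();
        float[] data = new float[batchSize * numFeatures];
        for (int r = 0; r < batchSize; r++) {
            int[] idx = rowIndices.get(r);
            double[] val = rowValues.get(r);
            if (idx.length != val.length) {
                throw new IllegalArgumentException("row " + r + ": indices/values length mismatch");
            }
            int offset = r * numFeatures;
            for (int i = 0; i < idx.length; i++) {
                // Without this check an oversized index would silently write into the next row
                if (idx[i] < 0 || idx[i] >= numFeatures) {
                    throw new IllegalArgumentException("row " + r + ": index out of range " + idx[i]);
                }
                data[offset + idx[i]] = (float) val[i];
            }
        }
        return new SparseBatch(data, new long[]{batchSize, numFeatures});
    }
}
```

With ONNX Runtime, `FloatBuffer.wrap(batch.data)` together with `batch.shape` matches the arguments of `OnnxTensor.createTensor(OrtEnvironment, FloatBuffer, long[])`.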
Specified by:
convertFeaturesList in class ExternalModel<T extends Output<T>,ai.onnxruntime.OnnxTensor,List<ai.onnxruntime.OnnxValue>>
Parameters:
input - The features using external indices.

protected List<ai.onnxruntime.OnnxValue> externalPrediction(ai.onnxruntime.OnnxTensor input)
Runs the session to make a prediction, then closes the input tensor once the prediction has been made.
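Closing the input inside the call means ownership of the tensor passes to the method, and the release should happen even if inference throws. A minimal sketch of that pattern with try-with-resources, where `NativeInput` is a hypothetical stand-in for `OnnxTensor` and the `runner` function stands in for the session run:

```java
import java.util.function.Function;

/** Hypothetical stand-in for a native-backed tensor that must be released. */
final class NativeInput implements AutoCloseable {
    final float[] data;
    boolean closed = false;
    NativeInput(float[] data) { this.data = data; }
    @Override public void close() { closed = true; }
}

/** Sketch: run a prediction and always release the input afterwards. */
final class ClosingPredictor {
    private ClosingPredictor() {}

    static <R> R predictAndClose(NativeInput input, Function<float[], R> runner) {
        // try-with-resources closes the input whether runner returns normally or throws
        try (NativeInput in = input) {
            return runner.apply(in.data);
        }
    }
}
```

Callers of such a method must not reuse the input afterwards, since it has already been released.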
Specified by:
externalPrediction in class ExternalModel<T extends Output<T>,ai.onnxruntime.OnnxTensor,List<ai.onnxruntime.OnnxValue>>
Parameters:
input - The input in the external model's format.

protected Prediction<T> convertOutput(List<ai.onnxruntime.OnnxValue> output, int numValidFeatures, Example<T> example)
Converts a tensor into a prediction.
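For a classifier whose graph emits one float score per class (an assumption; some exported graphs emit a label tensor plus a probability map instead), turning the output into a prediction reduces to inverting the `outputMapping` passed to createOnnxModel and picking the highest-scoring id. `ScoreDecoder`, `invert` and `argmaxLabel` are hypothetical names, and plain `String` labels stand in for Tribuo's output type `T`.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical helper: turns one row of class scores into a label. */
final class ScoreDecoder {
    private ScoreDecoder() {}

    /** Inverts a label -> ONNX id mapping into id -> label, rejecting duplicate ids. */
    static Map<Integer, String> invert(Map<String, Integer> outputMapping) {
        Map<Integer, String> inverse = new HashMap<>();
        for (Map.Entry<String, Integer> e : outputMapping.entrySet()) {
            if (inverse.put(e.getValue(), e.getKey()) != null) {
                throw new IllegalArgumentException("duplicate ONNX id " + e.getValue());
            }
        }
        return inverse;
    }

    /** Returns the label whose ONNX id has the largest score; the first index wins ties. */
    static String argmaxLabel(float[] scores, Map<Integer, String> idToLabel) {
        if (scores.length == 0) throw new IllegalArgumentException("empty score vector");
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) best = i;
        }
        String label = idToLabel.get(best);
        if (label == null) throw new IllegalStateException("no label for ONNX id " + best);
        return label;
    }
}
```

The documented method also receives `numValidFeatures`, which is presumably recorded on the resulting Prediction; the sketch leaves it out.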
Specified by:
convertOutput in class ExternalModel<T extends Output<T>,ai.onnxruntime.OnnxTensor,List<ai.onnxruntime.OnnxValue>>
Parameters:
output - The output of the external model.
numValidFeatures - The number of valid features in the input.
example - The input example, used to construct the Prediction.
Returns:
A Prediction representing this tensor output.

protected List<Prediction<T>> convertOutput(List<ai.onnxruntime.OnnxValue> output, int[] numValidFeatures, List<Example<T>> examples)
Converts a tensor into a prediction.
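In the batched case each example's feature count arrives in `numValidFeatures[i]`, so the output has to be split back into one row per example in the same order. A sketch assuming the scores arrive as a row-major `[batchSize, numOutputs]` float buffer; `RowResult` and `BatchSplitter` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical per-example result: the score row plus the feature count it came from. */
final class RowResult {
    final float[] scores;
    final int numValidFeatures;
    RowResult(float[] scores, int numValidFeatures) {
        this.scores = scores;
        this.numValidFeatures = numValidFeatures;
    }
}

final class BatchSplitter {
    private BatchSplitter() {}

    /** Splits a row-major [batchSize, numOutputs] buffer into one result per example. */
    static List<RowResult> split(float[] flat, int numOutputs, int[] numValidFeatures) {
        int batchSize = numValidFeatures.length;
        if (flat.length != batchSize * numOutputs) {
            throw new IllegalArgumentException(
                "expected " + (batchSize * numOutputs) + " values, got " + flat.length);
        }
        List<RowResult> results = new ArrayList<>(batchSize);
        for (int r = 0; r < batchSize; r++) {
            // Copy row r so each result owns its scores independently of the batch buffer
            float[] row = Arrays.copyOfRange(flat, r * numOutputs, (r + 1) * numOutputs);
            results.add(new RowResult(row, numValidFeatures[r]));
        }
        return results;
    }
}
```

Checking the buffer length up front catches a mismatch between the model's batch output and the number of examples before any row is misattributed.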
Specified by:
convertOutput in class ExternalModel<T extends Output<T>,ai.onnxruntime.OnnxTensor,List<ai.onnxruntime.OnnxValue>>
Parameters:
output - The output of the external model.
numValidFeatures - An array with the number of valid features in each example.
examples - The input examples, used to construct the Predictions.
Returns:
A list of Predictions representing this tensor output.

public Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
Description copied from class: Model
Gets the top n features associated with this model.
If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.
If the model cannot describe its top features, it returns Collections.emptyMap().
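The contract above (a single ALL_OUTPUTS entry when there are no per-output lists, every feature when n is negative, an empty map when features cannot be scored) can be illustrated for a hypothetical model with one scalar weight per feature. This is not necessarily how ONNXExternalModel behaves, since an external ONNX graph may not expose per-feature scores at all. The `ALL_OUTPUTS` string is a placeholder for `Model.ALL_OUTPUTS`, `Map.Entry` stands in for OLCUT's `Pair<String,Double>`, and ranking by absolute weight is an assumed scoring rule.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

final class TopFeaturesSketch {
    /** Placeholder for Model.ALL_OUTPUTS; the real constant's value may differ. */
    static final String ALL_OUTPUTS = "ALL_OUTPUTS";

    private TopFeaturesSketch() {}

    /**
     * Returns the n features with the largest absolute weight under a single ALL_OUTPUTS key.
     * A negative n returns every feature; a null weight map models "cannot score features".
     */
    static Map<String, List<Map.Entry<String, Double>>> topFeatures(Map<String, Double> weights, int n) {
        if (weights == null) {
            return Collections.emptyMap();
        }
        List<Map.Entry<String, Double>> ranked = new ArrayList<>();
        for (Map.Entry<String, Double> e : weights.entrySet()) {
            ranked.add(new AbstractMap.SimpleImmutableEntry<>(e.getKey(), e.getValue()));
        }
        // Largest magnitude first
        ranked.sort(Comparator.comparingDouble(
            (Map.Entry<String, Double> e) -> Math.abs(e.getValue())).reversed());
        int limit = n < 0 ? ranked.size() : Math.min(n, ranked.size());
        List<Map.Entry<String, Double>> top = new ArrayList<>(ranked.subList(0, limit));
        return Collections.singletonMap(ALL_OUTPUTS, top);
    }
}
```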
Specified by:
getTopFeatures in class Model<T extends Output<T>>
Parameters:
n - The number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.

protected Model<T> copy(String newName, ModelProvenance newProvenance)
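Description copied from class: Model
Copies a model, replacing its provenance and name with the supplied values.
Used to provide the provenance removal functionality.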
public void close()
Specified by:
close in interface AutoCloseable
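Because the class implements AutoCloseable and wraps a native ONNX Runtime session, try-with-resources is the natural way to scope its lifetime. The sketch below uses a hypothetical `ExternalModelHandle` in place of ONNXExternalModel so it runs without Tribuo; with the real class, the resource would be the model returned by createOnnxModel.

```java
/** Hypothetical stand-in for a model wrapper that owns a native session. */
final class ExternalModelHandle implements AutoCloseable {
    private boolean closed = false;

    /** Toy prediction rule, used only to show that calls work before close. */
    String predict(float[] features) {
        if (closed) throw new IllegalStateException("model has been closed");
        return features.length > 0 && features[0] > 0 ? "positive" : "negative";
    }

    boolean isClosed() { return closed; }

    @Override public void close() {
        // A real wrapper would presumably release its native session here
        closed = true;
    }
}

final class CloseDemo {
    private CloseDemo() {}

    /** Predicts once and guarantees the handle is closed afterwards. */
    static String predictOnce(ExternalModelHandle handle, float[] features) {
        try (ExternalModelHandle h = handle) {
            return h.predict(features);
        }
    }
}
```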
public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, ai.onnxruntime.OrtSession.SessionOptions opts, String filename, String inputName) throws ai.onnxruntime.OrtException
Creates an ONNXExternalModel by loading the model from disk.
Type Parameters:
T - The type of the output.
Parameters:
factory - The output factory to use.
featureMapping - The feature mapping between Tribuo names and ONNX integer ids.
outputMapping - The output mapping between Tribuo outputs and ONNX integer ids.
featureTransformer - The transformation function for the features.
outputTransformer - The transformation function for the outputs.
opts - The session options for the ONNX model.
filename - The model path.
inputName - The name of the input node.
Throws:
ai.onnxruntime.OrtException - If the onnx-runtime native library call failed.

public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, ai.onnxruntime.OrtSession.SessionOptions opts, Path path, String inputName) throws ai.onnxruntime.OrtException
Creates an ONNXExternalModel by loading the model from disk.
Type Parameters:
T - The type of the output.
Parameters:
factory - The output factory to use.
featureMapping - The feature mapping between Tribuo names and ONNX integer ids.
outputMapping - The output mapping between Tribuo outputs and ONNX integer ids.
featureTransformer - The transformation function for the features.
outputTransformer - The transformation function for the outputs.
opts - The session options for the ONNX model.
path - The model path.
inputName - The name of the input node.
Throws:
ai.onnxruntime.OrtException - If the onnx-runtime native library call failed.

Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.