Class ONNXExternalModel<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.ExternalModel<T, ai.onnxruntime.OnnxTensor, List<ai.onnxruntime.OnnxValue>>
org.tribuo.interop.onnx.ONNXExternalModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, AutoCloseable

public final class ONNXExternalModel<T extends Output<T>> extends ExternalModel<T, ai.onnxruntime.OnnxTensor, List<ai.onnxruntime.OnnxValue>> implements AutoCloseable
A Tribuo wrapper around an ONNX model.

N.B. ONNX support is experimental, and may change without a major version bump.

  • Method Details

    • rebuild

      public void rebuild(ai.onnxruntime.OrtSession.SessionOptions newOptions) throws ai.onnxruntime.OrtException
      Closes the session and rebuilds it using the supplied options.

      Used to select a different backend, change the number of inference threads, etc.

      Parameters:
      newOptions - The new session options.
      Throws:
      ai.onnxruntime.OrtException - If the model failed to rebuild the session with the supplied options.
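The close-then-rebuild contract can be sketched without native code. `FakeSession` and `RebuildableModel` below are hypothetical stand-ins, not ONNX Runtime or Tribuo classes, and a plain integer thread count stands in for the real `OrtSession.SessionOptions` (which ONNX Runtime's Java API configures with calls such as `setIntraOpNumThreads`).

```java
import java.util.Objects;

// Hypothetical stand-in for a native inference session; NOT the ONNX Runtime API.
final class FakeSession implements AutoCloseable {
    final int threads;
    boolean closed = false;

    FakeSession(int threads) {
        this.threads = threads;
    }

    @Override
    public void close() {
        closed = true;
    }
}

// Hypothetical holder mirroring the documented rebuild contract:
// close the current session, then build a new one from the new options.
final class RebuildableModel implements AutoCloseable {
    private FakeSession session;

    RebuildableModel(int threads) {
        this.session = new FakeSession(threads);
    }

    void rebuild(int newThreads) {
        Objects.requireNonNull(session, "model already closed");
        session.close();                       // release the old session first
        session = new FakeSession(newThreads); // then rebuild with the new options
    }

    FakeSession session() {
        return session;
    }

    @Override
    public void close() {
        if (session != null) {
            session.close();
            session = null;
        }
    }
}
```

Closing before rebuilding means the old session's resources are released before the new session is allocated, which matches the order stated in the method description.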
    • convertFeatures

      protected ai.onnxruntime.OnnxTensor convertFeatures(SparseVector input)
      Description copied from class: ExternalModel
      Converts from a SparseVector using the external model's indices into the ingestion format for the external model.
      Specified by:
      convertFeatures in class ExternalModel<T extends Output<T>, ai.onnxruntime.OnnxTensor, List<ai.onnxruntime.OnnxValue>>
      Parameters:
      input - The features using external indices.
      Returns:
      The ingestion format for the external model.
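As a rough illustration of what "ingestion format" can mean, the sketch below densifies a sparse vector, given as parallel index and value arrays, into one `float[]` row. It assumes the ONNX model expects a dense float input of width `numFeatures`; `SparseToDense` is a hypothetical helper, not Tribuo's implementation, which depends on the configured feature transformer.

```java
// Hypothetical helper: densifies a sparse feature vector given as parallel
// (index, value) arrays into a float[] of length numFeatures, i.e. one row
// of a [1, numFeatures] float input.
final class SparseToDense {
    static float[] densify(int[] indices, double[] values, int numFeatures) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values must have the same length");
        }
        float[] dense = new float[numFeatures]; // features not present stay 0.0f
        for (int i = 0; i < indices.length; i++) {
            int idx = indices[i];
            if (idx < 0 || idx >= numFeatures) {
                throw new IllegalArgumentException("feature index out of range: " + idx);
            }
            dense[idx] = (float) values[i];
        }
        return dense;
    }
}
```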
    • convertFeaturesList

      protected ai.onnxruntime.OnnxTensor convertFeaturesList(List<SparseVector> input)
      Description copied from class: ExternalModel
      Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model.
      Specified by:
      convertFeaturesList in class ExternalModel<T extends Output<T>, ai.onnxruntime.OnnxTensor, List<ai.onnxruntime.OnnxValue>>
      Parameters:
      input - The features using external indices.
      Returns:
      The ingestion format for the external model.
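For a batch, one common layout is a single row-major buffer plus a `[batchSize, numFeatures]` shape, which is the form expected by ONNX Runtime's `OnnxTensor.createTensor(env, FloatBuffer, long[])` overload. The sketch assumes each example has already been densified; `BatchPacker` is hypothetical and is not how Tribuo necessarily builds its batch tensor.

```java
import java.util.List;

// Hypothetical helper: packs a batch of dense rows into one row-major float[]
// plus the matching [batchSize, numFeatures] shape.
final class BatchPacker {
    final float[] data;
    final long[] shape;

    private BatchPacker(float[] data, long[] shape) {
        this.data = data;
        this.shape = shape;
    }

    static BatchPacker pack(List<float[]> rows, int numFeatures) {
        float[] data = new float[rows.size() * numFeatures];
        for (int r = 0; r < rows.size(); r++) {
            float[] row = rows.get(r);
            if (row.length != numFeatures) {
                throw new IllegalArgumentException("row " + r + " has length " + row.length);
            }
            // Row r occupies data[r * numFeatures, (r + 1) * numFeatures).
            System.arraycopy(row, 0, data, r * numFeatures, numFeatures);
        }
        return new BatchPacker(data, new long[]{rows.size(), numFeatures});
    }
}
```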
    • externalPrediction

      protected List<ai.onnxruntime.OnnxValue> externalPrediction(ai.onnxruntime.OnnxTensor input)
      Runs the session to make a prediction.

      Closes the input tensor after the prediction has been made.

      Specified by:
      externalPrediction in class ExternalModel<T extends Output<T>, ai.onnxruntime.OnnxTensor, List<ai.onnxruntime.OnnxValue>>
      Parameters:
      input - The input in the external model's format.
      Returns:
      The list of ONNX values produced by the session, representing the output.
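The "close the input after predicting" behaviour maps naturally onto try-with-resources. The generic `RunAndClose.runAndClose` helper below is a hypothetical sketch of that contract: the input is closed whether the prediction returns normally or throws. It is not Tribuo's code.

```java
import java.util.function.Function;

// Hypothetical helper mirroring the documented contract: run the prediction,
// then close the input whether or not the run succeeded.
final class RunAndClose {
    static <I extends AutoCloseable, O> O runAndClose(I input, Function<I, O> run) throws Exception {
        try (I in = input) { // try-with-resources closes the input on exit
            return run.apply(in);
        }
    }
}
```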
    • convertOutput

      protected Prediction<T> convertOutput(List<ai.onnxruntime.OnnxValue> output, int numValidFeatures, Example<T> example)
      Converts a tensor into a prediction. Closes the output tensor after it's been converted.
      Specified by:
      convertOutput in class ExternalModel<T extends Output<T>, ai.onnxruntime.OnnxTensor, List<ai.onnxruntime.OnnxValue>>
      Parameters:
      output - The output of the external model.
      numValidFeatures - The number of valid features in the input.
      example - The input example, used to construct the Prediction.
      Returns:
      A Prediction representing this tensor output.
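For a classifier, converting the output often amounts to picking the highest-scoring output id and looking up the output registered under that id. The sketch assumes the model emits one score per output id and that `labels[id]` comes from the output mapping; `ArgmaxDecoder` is hypothetical, and the real conversion depends on the supplied output transformer.

```java
// Hypothetical decoder: assumes one score per output id and that labels[id]
// is the output registered under that id in the output mapping.
final class ArgmaxDecoder {
    static String decode(float[] scores, String[] labels) {
        if (scores.length != labels.length || scores.length == 0) {
            throw new IllegalArgumentException("need one score per label, and at least one label");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i; // strict '>' means ties keep the lowest id
            }
        }
        return labels[best];
    }
}
```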
    • convertOutput

      protected List<Prediction<T>> convertOutput(List<ai.onnxruntime.OnnxValue> output, int[] numValidFeatures, List<Example<T>> examples)
      Converts a tensor into a list of predictions. Closes the output tensor after it's been converted.
      Specified by:
      convertOutput in class ExternalModel<T extends Output<T>, ai.onnxruntime.OnnxTensor, List<ai.onnxruntime.OnnxValue>>
      Parameters:
      output - The output of the external model.
      numValidFeatures - An array with the number of valid features in each example.
      examples - The input examples, used to construct the Predictions.
      Returns:
      A list of Prediction representing this tensor output.
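The batch overload pairs row `i` of the output with `numValidFeatures[i]`. The sketch below shows that pairing for a hypothetical score matrix, again assuming one score per output id; `BatchDecoder` and its nested `Decoded` class are illustrative only and stand in for Tribuo's `Prediction`.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical batch decoder: row i of scores pairs with numValidFeatures[i].
final class BatchDecoder {
    static final class Decoded {
        final String label;
        final int numValidFeatures;

        Decoded(String label, int numValidFeatures) {
            this.label = label;
            this.numValidFeatures = numValidFeatures;
        }
    }

    static List<Decoded> decodeAll(float[][] scores, int[] numValidFeatures, String[] labels) {
        if (scores.length != numValidFeatures.length) {
            throw new IllegalArgumentException("one feature count per row is required");
        }
        List<Decoded> out = new ArrayList<>(scores.length);
        for (int row = 0; row < scores.length; row++) {
            if (labels.length == 0 || scores[row].length != labels.length) {
                throw new IllegalArgumentException("row " + row + " must have one score per label");
            }
            int best = 0; // argmax over this row's scores
            for (int i = 1; i < scores[row].length; i++) {
                if (scores[row][i] > scores[row][best]) {
                    best = i;
                }
            }
            out.add(new Decoded(labels[best], numValidFeatures[row]));
        }
        return out;
    }
}
```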
    • getTopFeatures

      public Map<String, List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
      Description copied from class: Model
      Gets the top n features associated with this model.

      If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.

      If the model cannot describe its top features, then it returns Collections.emptyMap().

      Specified by:
      getTopFeatures in class Model<T extends Output<T>>
      Parameters:
      n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
      Returns:
      a map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the model
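The general `Model` contract above can be sketched for a hypothetical model that has one global weight per feature. Ranking by absolute weight is an assumption made for illustration, and `ALL_OUTPUTS` here is a local stand-in for `Model.ALL_OUTPUTS`; `TopFeatures` is not part of Tribuo.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the top-n contract for a model with a single,
// global weight per feature.
final class TopFeatures {
    static final String ALL_OUTPUTS = "ALL_OUTPUTS"; // local stand-in for Model.ALL_OUTPUTS

    // Sorts by descending absolute weight (an assumption) and keeps n entries, or all when n < 0.
    static List<Map.Entry<String, Double>> topN(Map<String, Double> weights, int n) {
        List<Map.Entry<String, Double>> entries = new ArrayList<>(weights.entrySet());
        entries.sort(Comparator.comparingDouble((Map.Entry<String, Double> e) -> Math.abs(e.getValue())).reversed());
        int limit = n < 0 ? entries.size() : Math.min(n, entries.size());
        return new ArrayList<>(entries.subList(0, limit));
    }

    // No per-output lists, so everything goes under a single key;
    // an empty map signals that the model cannot describe its features.
    static Map<String, List<Map.Entry<String, Double>>> topFeatures(Map<String, Double> weights, int n) {
        if (weights.isEmpty()) {
            return Collections.emptyMap();
        }
        return Collections.singletonMap(ALL_OUTPUTS, topN(weights, n));
    }
}
```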
    • copy

      protected Model<T> copy(String newName, ModelProvenance newProvenance)
      Description copied from class: Model
      Copies a model, replacing its provenance and name with the supplied values.

      Used to provide the provenance removal functionality.

      Specified by:
      copy in class Model<T extends Output<T>>
      Parameters:
      newName - The new name.
      newProvenance - The new provenance.
      Returns:
      A copy of the model.
    • close

      public void close()
      Specified by:
      close in interface AutoCloseable
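Because the class implements AutoCloseable, it can be managed with try-with-resources. The hypothetical `TrackedResource` below (not a Tribuo or ONNX Runtime type) records close calls to show the Java language rule that resources declared in one `try` are closed in reverse order of declaration.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical resource that records when it is closed; a stand-in for
// objects such as a loaded model or an input tensor.
final class TrackedResource implements AutoCloseable {
    static final List<String> CLOSE_LOG = new ArrayList<>();
    private final String name;

    TrackedResource(String name) {
        this.name = name;
    }

    @Override
    public void close() {
        CLOSE_LOG.add(name);
    }

    // Two resources in one try: Java closes them in reverse order of
    // declaration, so "input" is released before "model".
    static void demo() {
        try (TrackedResource model = new TrackedResource("model");
             TrackedResource input = new TrackedResource("input")) {
            // ... score with the model here ...
        }
    }
}
```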
    • createOnnxModel

      public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, ai.onnxruntime.OrtSession.SessionOptions opts, String filename, String inputName) throws ai.onnxruntime.OrtException
      Creates an ONNXExternalModel by loading the model from disk.
      Type Parameters:
      T - The type of the output.
      Parameters:
      factory - The output factory to use.
      featureMapping - The feature mapping between Tribuo names and ONNX integer ids.
      outputMapping - The output mapping between Tribuo outputs and ONNX integer ids.
      featureTransformer - The transformation function for the features.
      outputTransformer - The transformation function for the outputs.
      opts - The session options for the ONNX model.
      filename - The model path.
      inputName - The name of the input node.
      Returns:
      An ONNXExternalModel ready to score new inputs.
      Throws:
      ai.onnxruntime.OrtException - If the onnx-runtime native library call failed.
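The `featureMapping` argument maps Tribuo feature names to ONNX integer ids. One simple way to build it, assuming the ONNX model's input columns follow a known ordered list of feature names, is to assign ids by position. `FeatureMappings.byPosition` is a hypothetical helper; the correct ids ultimately depend on how the ONNX model was exported.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: assigns ONNX integer ids to Tribuo feature names by
// position, assuming the ONNX input columns follow the given order.
final class FeatureMappings {
    static Map<String, Integer> byPosition(List<String> featureNames) {
        Map<String, Integer> mapping = new HashMap<>();
        for (int i = 0; i < featureNames.size(); i++) {
            Integer previous = mapping.put(featureNames.get(i), i);
            if (previous != null) {
                throw new IllegalArgumentException("duplicate feature name: " + featureNames.get(i));
            }
        }
        return mapping;
    }
}
```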
    • createOnnxModel

      public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, ai.onnxruntime.OrtSession.SessionOptions opts, Path path, String inputName) throws ai.onnxruntime.OrtException
      Creates an ONNXExternalModel by loading the model from disk.
      Type Parameters:
      T - The type of the output.
      Parameters:
      factory - The output factory to use.
      featureMapping - The feature mapping between Tribuo names and ONNX integer ids.
      outputMapping - The output mapping between Tribuo outputs and ONNX integer ids.
      featureTransformer - The transformation function for the features.
      outputTransformer - The transformation function for the outputs.
      opts - The session options for the ONNX model.
      path - The model path.
      inputName - The name of the input node.
      Returns:
      An ONNXExternalModel ready to score new inputs.
      Throws:
      ai.onnxruntime.OrtException - If the onnx-runtime native library call failed.
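Both factory methods also take an `outputMapping` from Tribuo outputs to ONNX integer ids. When decoding scores it is often convenient to invert that map into an array indexed by id. The sketch below uses `String` keys as a stand-in for the output type `T` and assumes the ids should be exactly `0..size-1`; `OutputMappings.invert` is hypothetical.

```java
import java.util.Map;

// Hypothetical helper: inverts an output mapping (output -> ONNX id) into an
// array indexed by id, checking that the ids are exactly 0..size-1.
final class OutputMappings {
    static String[] invert(Map<String, Integer> outputMapping) {
        String[] byId = new String[outputMapping.size()];
        for (Map.Entry<String, Integer> e : outputMapping.entrySet()) {
            int id = e.getValue();
            if (id < 0 || id >= byId.length || byId[id] != null) {
                throw new IllegalArgumentException("ids must be unique and in [0, " + byId.length + ")");
            }
            byId[id] = e.getKey();
        }
        return byId;
    }
}
```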