Package org.tribuo

Interface ONNXExportable

All Known Implementing Classes:
FMClassificationModel, FMMultiLabelModel, FMRegressionModel, LibLinearClassificationModel, LibLinearRegressionModel, LibSVMClassificationModel, LibSVMRegressionModel, LinearSGDModel, LinearSGDModel, LinearSGDModel, SparseLinearModel, WeightedEnsembleModel

public interface ONNXExportable
An interface which denotes that this Model can be exported as an ONNX model.

Tribuo models are exported with a single input of size [-1, numFeatures] and a single output of size [-1, numOutputDimensions]. The first dimension in both is an unbounded dimension called "batch", which denotes the batch size.
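The batch convention can be illustrated with a small shape check. This is a hypothetical plain-Java helper, not part of Tribuo or ONNX Runtime: a declared dimension of -1 accepts any size, while every other dimension must match exactly.

```java
/** Hypothetical helper: compares a concrete input shape with an exported model's declared shape. */
final class BatchShapeCheck {
    /** Placeholder used in the declared shape for the unbounded "batch" dimension. */
    static final long BATCH = -1L;

    private BatchShapeCheck() {}

    /** Declared input shape [batch, numFeatures] of an exported model. */
    static long[] declaredInputShape(int numFeatures) {
        return new long[] {BATCH, numFeatures};
    }

    /** True if every fixed dimension matches; the batch dimension accepts any size. */
    static boolean matches(long[] declared, long[] actual) {
        if (declared.length != actual.length) {
            return false; // different rank
        }
        for (int i = 0; i < declared.length; i++) {
            if (declared[i] != BATCH && declared[i] != actual[i]) {
                return false; // a fixed dimension differs
            }
        }
        return true;
    }
}
```

For example, declaredInputShape(4) is {-1, 4}, so a batch of shape {32, 4} matches while {32, 5} does not. The same check applies to the output shape [-1, numOutputDimensions].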

ONNX-exported models use floats where Tribuo uses doubles. This is because ONNX deployment environments support fp64 comparatively poorly relative to fp32. In addition, fp32 runs better on the various accelerator backends available in ONNX Runtime.
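To see what the fp32 choice means for inputs, the sketch below (a hypothetical helper, not a Tribuo API) narrows double-valued feature rows to floats and measures the rounding error this introduces. Because of this narrowing, predictions from an exported model may differ slightly from those of the original Tribuo model.

```java
/** Hypothetical helper showing the double-to-float narrowing an fp32 ONNX input implies. */
final class Fp32Narrowing {
    private Fp32Narrowing() {}

    /** Narrows a batch of double-valued feature rows to float rows. */
    static float[][] toFloat(double[][] batch) {
        float[][] out = new float[batch.length][];
        for (int i = 0; i < batch.length; i++) {
            out[i] = new float[batch[i].length];
            for (int j = 0; j < batch[i].length; j++) {
                out[i][j] = (float) batch[i][j]; // rounds to the nearest representable float
            }
        }
        return out;
    }

    /** Largest absolute difference introduced by rounding each value to float. */
    static double maxRoundingError(double[][] batch) {
        double max = 0.0;
        for (double[] row : batch) {
            for (double v : row) {
                max = Math.max(max, Math.abs(v - (double) (float) v));
            }
        }
        return max;
    }
}
```

Values such as 0.5 and 0.25 are exactly representable in both formats, so their error is 0.0; 0.1 is not, and narrowing it introduces an error of roughly 1.5e-9.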

  • Field Summary

    Fields
    Modifier and Type | Field | Description
    static final String | PROVENANCE_METADATA_FIELD | The name of the ONNX metadata field where the provenance information is stored in exported models.
    static final com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization | SERIALIZER | The provenance serializer.
  • Method Summary

    Modifier and Type | Method | Description
    static <M extends com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>> ai.onnx.proto.OnnxMl.ModelProto | buildModel(ONNXContext onnxContext, String domain, long modelVersion, M model) | Creates an ONNX model protobuf for the supplied context.
    ai.onnx.proto.OnnxMl.ModelProto | exportONNXModel(String domain, long modelVersion) | Exports this Model as an ONNX protobuf.
    default void | saveONNXModel(String domain, long modelVersion, Path outputPath) | Exports this Model as an ONNX file.
    default String | serializeProvenance(ModelProvenance provenance) | Serializes the model provenance to a String.
    ONNXNode | writeONNXGraph(ONNXRef<?> input) | Writes this Model into OnnxMl.GraphProto.Builder inside the input's ONNXContext.
  • Field Details

    • SERIALIZER

      static final com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization SERIALIZER
      The provenance serializer.
    • PROVENANCE_METADATA_FIELD

      static final String PROVENANCE_METADATA_FIELD
      The name of the ONNX metadata field where the provenance information is stored in exported models.
  • Method Details

    • buildModel

      static <M extends com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>> ai.onnx.proto.OnnxMl.ModelProto buildModel(ONNXContext onnxContext, String domain, long modelVersion, M model)
      Creates an ONNX model protobuf for the supplied context.
      Type Parameters:
      M - The type of the provenanced model.
      Parameters:
      onnxContext - The context which contains the ONNX graph.
      domain - Domain for the produced model.
      modelVersion - Model version for the produced model.
      model - The provenanced Tribuo model from which the ONNX model is derived. Its DocString and Tribuo provenance data are written into the ONNX ModelProto.
      Returns:
      An ONNX model proto of the graph represented by the supplied ONNXContext.
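Judging from the parameter descriptions, buildModel combines the caller's domain and modelVersion with the DocString and provenance of the supplied model, the provenance presumably being stored under the metadata field named by PROVENANCE_METADATA_FIELD. The sketch below mirrors only that bookkeeping with a hypothetical ModelEnvelopeSketch class; the real method returns an ai.onnx.proto.OnnxMl.ModelProto that also carries the graph, and the key "PROVENANCE_KEY" used in the usage note is made up.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical stand-in for the top-level fields buildModel fills in (not a ModelProto). */
final class ModelEnvelopeSketch {
    final String domain;
    final long modelVersion;
    final String docString;
    final Map<String, String> metadata;

    private ModelEnvelopeSketch(String domain, long modelVersion, String docString,
                                Map<String, String> metadata) {
        this.domain = domain;
        this.modelVersion = modelVersion;
        this.docString = docString;
        this.metadata = metadata;
    }

    /**
     * Copies the caller's domain and version, takes the doc string from the source model,
     * and stores its serialized provenance under provenanceField.
     */
    static ModelEnvelopeSketch build(String domain, long modelVersion, String docString,
                                     String provenanceField, String serializedProvenance) {
        Map<String, String> metadata = new LinkedHashMap<>();
        metadata.put(provenanceField, serializedProvenance);
        return new ModelEnvelopeSketch(domain, modelVersion, docString,
                Collections.unmodifiableMap(metadata));
    }
}
```

For instance, build("org.example", 2L, "A linear model", "PROVENANCE_KEY", provenanceString) yields an envelope whose metadata maps "PROVENANCE_KEY" to provenanceString.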
    • exportONNXModel

      ai.onnx.proto.OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion)
      Exports this Model as an ONNX protobuf.
      Parameters:
      domain - A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
      modelVersion - A version number for this model.
      Returns:
      The ONNX ModelProto representing this Tribuo Model.
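Both exportONNXModel and saveONNXModel expect domain to be a reverse-DNS name such as org.tribuo.classification.sgd.linear. A caller could run a pre-flight check like the hypothetical one below. This page does not say that Tribuo validates the string, and the rule encoded here (at least two dot-separated labels, each starting with a letter) is an illustrative assumption, not an ONNX requirement.

```java
import java.util.regex.Pattern;

/** Hypothetical check that a domain string looks like a reverse-DNS name. */
final class ReverseDnsDomain {
    // Two or more dot-separated labels; each starts with a letter and continues with
    // letters, digits, underscores or hyphens. Illustrative rule only.
    private static final Pattern REVERSE_DNS =
            Pattern.compile("[A-Za-z][A-Za-z0-9_-]*(\\.[A-Za-z][A-Za-z0-9_-]*)+");

    private ReverseDnsDomain() {}

    static boolean looksLikeReverseDns(String domain) {
        return domain != null && REVERSE_DNS.matcher(domain).matches();
    }
}
```

Under this rule "org.tribuo.classification.sgd.linear" passes, while "tribuo" (a single label) and "org..tribuo" (an empty label) are rejected.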
    • writeONNXGraph

      ONNXNode writeONNXGraph(ONNXRef<?> input)
      Writes this Model into OnnxMl.GraphProto.Builder inside the input's ONNXContext.
      Parameters:
      input - The input to the model graph.
      Returns:
      The output node of the model graph.
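writeONNXGraph starts from a reference to the graph input, appends the model's operators, and returns the node that produces the output. The sketch below imitates that contract with a hypothetical TinyGraph builder; it is not ONNXContext, and the node and value names are made up. The real method reaches the graph builder through the input's ONNXContext, whereas this sketch passes the graph explicitly. For a linear model it appends an ONNX MatMul followed by an Add; Tribuo's own models may emit different operators (a single Gemm, for example), and in a real ONNX graph the weights W and bias b would typically be initializers.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

/** Hypothetical, minimal graph builder standing in for ONNXContext. */
final class TinyGraph {
    /** One operator node: op type, input value names, output value name. */
    static final class Node {
        final String opType;
        final List<String> inputs;
        final String output;

        Node(String opType, List<String> inputs, String output) {
            this.opType = opType;
            this.inputs = inputs;
            this.output = output;
        }
    }

    final List<Node> nodes = new ArrayList<>();
    private int counter = 0;

    /** Appends a node and returns the name of the value it produces. */
    String add(String opType, String... inputs) {
        String output = opType.toLowerCase(Locale.ROOT) + "_" + (counter++);
        nodes.add(new Node(opType, Arrays.asList(inputs), output));
        return output;
    }

    /** Appends scores = input x weights + bias and returns the name of the graph output. */
    static String writeLinearGraph(TinyGraph graph, String input, String weights, String bias) {
        String product = graph.add("MatMul", input, weights);
        return graph.add("Add", product, bias);
    }
}
```

Calling writeLinearGraph on an empty graph with inputs "input", "W" and "b" appends two nodes and returns "add_1", the output of the Add node, which consumes "matmul_0".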
    • saveONNXModel

      default void saveONNXModel(String domain, long modelVersion, Path outputPath) throws IOException
      Exports this Model as an ONNX file.
      Parameters:
      domain - A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
      modelVersion - A version number for this model.
      outputPath - The path to write to.
      Throws:
      IOException - if the file could not be written to.
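saveONNXModel is a default method, so implementers only supply exportONNXModel and inherit the file-writing step. The pattern can be sketched with a hypothetical BytesExportable interface in which a byte array stands in for the ModelProto; this is not Tribuo's source, only an illustration of a default save method built on an abstract export method.

```java
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical mirror of the export/save split; byte arrays stand in for ModelProto. */
interface BytesExportable {
    /** Produces the serialized model for the given domain and version. */
    byte[] exportModel(String domain, long modelVersion);

    /** Writes the exported bytes to outputPath, creating or truncating the file. */
    default void saveModel(String domain, long modelVersion, Path outputPath) throws IOException {
        byte[] bytes = exportModel(domain, modelVersion);
        try (OutputStream os = Files.newOutputStream(outputPath)) {
            os.write(bytes);
        }
    }
}
```

With a real model that implements ONNXExportable, the corresponding documented call is model.saveONNXModel("org.tribuo.classification.sgd.linear", 1L, Paths.get("model.onnx")), which may throw IOException.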
    • serializeProvenance

      default String serializeProvenance(ModelProvenance provenance)
      Serializes the model provenance to a String.
      Parameters:
      provenance - The provenance to serialize.
      Returns:
      The serialized form of the ModelProvenance.
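serializeProvenance turns a ModelProvenance into a String, which is what allows provenance to be carried in an ONNX metadata field. The round trip can be sketched with a hypothetical flat key/value map in place of ModelProvenance. The real method presumably relies on the SERIALIZER field (OLCUT's provenance serialization), which handles nested provenance that this sketch does not.

```java
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical stand-in for provenance serialization using a flat key/value map. */
final class ProvenanceStrings {
    private ProvenanceStrings() {}

    /**
     * Writes one "key=value" line per entry, sorted by key so the output is deterministic.
     * Assumes keys contain no '=' and neither keys nor values contain a newline.
     */
    static String serialize(Map<String, String> provenance) {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, String> e : new TreeMap<>(provenance).entrySet()) {
            sb.append(e.getKey()).append('=').append(e.getValue()).append('\n');
        }
        return sb.toString();
    }

    /** Parses the output of serialize, splitting each line at its first '='. */
    static Map<String, String> deserialize(String text) {
        Map<String, String> out = new TreeMap<>();
        for (String line : text.split("\n")) {
            if (line.isEmpty()) {
                continue; // tolerate an empty input
            }
            int eq = line.indexOf('=');
            out.put(line.substring(0, eq), line.substring(eq + 1));
        }
        return out;
    }
}
```

Splitting at the first '=' lets values such as "a=b" survive the round trip, and sorting keeps the serialized String stable across runs.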