Package org.tribuo
Interface ONNXExportable
All Known Implementing Classes:
FMClassificationModel, FMMultiLabelModel, FMRegressionModel, LibLinearClassificationModel, LibLinearRegressionModel, LibSVMClassificationModel, LibSVMRegressionModel, LinearSGDModel, LinearSGDModel, LinearSGDModel, SparseLinearModel, WeightedEnsembleModel
public interface ONNXExportable
An interface which denotes that this Model can be exported as an ONNX model.
Tribuo models are exported with a single input of shape [-1, numFeatures] and a single output of shape [-1, numOutputDimensions]. The first dimension of each is an unbound dimension named "batch", which holds the batch size.
ONNX-exported models use floats where Tribuo uses doubles, because fp64 support in ONNX deployment environments is comparatively poor next to fp32. In addition, fp32 runs more efficiently on the various accelerator backends available in ONNX Runtime.
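Because an exported model takes a single fp32 input of shape [-1, numFeatures], a caller holding Tribuo's double-valued features has to narrow them to float and stack them into a batch. The sketch below shows that conversion. `OnnxInputPacker` and `toFloatBatch` are hypothetical names, not Tribuo API, and the sketch assumes each row is already a dense vector ordered by the model's feature IDs.

```java
/**
 * Hypothetical helper (not part of Tribuo): packs double feature vectors into
 * the float layout an exported model expects, shape [batch, numFeatures].
 * Assumes each row is already dense and ordered by the model's feature IDs.
 */
final class OnnxInputPacker {

    private OnnxInputPacker() {}

    /**
     * Narrows each double to float and checks that every row has exactly
     * numFeatures columns. The number of rows (the "batch" dimension) is unconstrained.
     */
    static float[][] toFloatBatch(double[][] features, int numFeatures) {
        float[][] batch = new float[features.length][numFeatures];
        for (int i = 0; i < features.length; i++) {
            if (features[i].length != numFeatures) {
                throw new IllegalArgumentException("Row " + i + " has "
                        + features[i].length + " features, expected " + numFeatures);
            }
            for (int j = 0; j < numFeatures; j++) {
                // fp64 -> fp32 rounds to the nearest float, so some precision can be lost.
                batch[i][j] = (float) features[i][j];
            }
        }
        return batch;
    }
}
```

The resulting `float[][]` can then be handed to an ONNX runtime; in ONNX Runtime's Java API, for example, `OnnxTensor.createTensor(env, batch)` accepts a multidimensional float array.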
-
Field Summary
| Modifier and Type | Field | Description |
| static final String | PROVENANCE_METADATA_FIELD | The name of the ONNX metadata field where the provenance information is stored in exported models. |
| static final com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization | SERIALIZER | The provenance serializer. |
-
Method Summary
| Modifier and Type | Method | Description |
| static <M extends com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>> ai.onnx.proto.OnnxMl.ModelProto | buildModel(ONNXContext onnxContext, String domain, long modelVersion, M model) | Creates an ONNX model protobuf for the supplied context. |
| ai.onnx.proto.OnnxMl.ModelProto | exportONNXModel(String domain, long modelVersion) | Exports this Model as an ONNX protobuf. |
| default void | saveONNXModel(String domain, long modelVersion, Path outputPath) | Exports this Model as an ONNX file. |
| default String | serializeProvenance(ModelProvenance provenance) | Serializes the model provenance to a String. |
|  | writeONNXGraph(ONNXRef<?> input) |  |
-
Field Details
-
SERIALIZER
static final com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization SERIALIZER
The provenance serializer.
-
PROVENANCE_METADATA_FIELD
static final String PROVENANCE_METADATA_FIELD
The name of the ONNX metadata field where the provenance information is stored in exported models.
-
Method Details
-
buildModel
static <M extends com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>> ai.onnx.proto.OnnxMl.ModelProto buildModel(ONNXContext onnxContext, String domain, long modelVersion, M model)
Creates an ONNX model protobuf for the supplied context.
Type Parameters:
M - The type of the provenanced model.
Parameters:
onnxContext - The context which contains the ONNX graph.
domain - Domain for the produced model.
modelVersion - Model version for the produced model.
model - Provenanced Tribuo model from which this model is derived. The DocString and Tribuo provenance data from this model will be written into the ONNX ModelProto.
Returns:
An ONNX model proto of the graph represented by the supplied ONNXContext.
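buildModel writes the model's DocString and serialized Tribuo provenance into the ONNX ModelProto alongside the domain and version. As a rough picture of what ends up in the header, the hypothetical `OnnxModelHeader` class below mirrors the ModelProto fields `domain`, `model_version`, `doc_string`, and `metadata_props`. It is a stdlib-only stand-in, not the protobuf API, and `PROVENANCE_KEY` is a placeholder for the real `PROVENANCE_METADATA_FIELD` value.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical stand-in (not Tribuo or ONNX code) for the top-level ModelProto
 * fields that buildModel fills in: domain, model_version, doc_string, and a
 * metadata_props entry holding the serialized Tribuo provenance.
 */
final class OnnxModelHeader {
    /** Placeholder key; the real key is the PROVENANCE_METADATA_FIELD constant. */
    static final String PROVENANCE_KEY = "provenance-placeholder";

    private final String domain;
    private final long modelVersion;
    private final String docString;
    private final Map<String, String> metadataProps;

    private OnnxModelHeader(String domain, long modelVersion, String docString,
                            Map<String, String> metadataProps) {
        this.domain = domain;
        this.modelVersion = modelVersion;
        this.docString = docString;
        this.metadataProps = metadataProps;
    }

    /** Builds a header, storing the already-serialized provenance under PROVENANCE_KEY. */
    static OnnxModelHeader of(String domain, long modelVersion, String docString,
                              String serializedProvenance) {
        Map<String, String> metadata = new LinkedHashMap<>();
        metadata.put(PROVENANCE_KEY, serializedProvenance);
        return new OnnxModelHeader(domain, modelVersion, docString,
                Collections.unmodifiableMap(metadata));
    }

    String domain() { return domain; }
    long modelVersion() { return modelVersion; }
    String docString() { return docString; }
    Map<String, String> metadataProps() { return metadataProps; }
}
```

Keeping provenance in a string-valued metadata entry means any ONNX consumer can carry it along without understanding Tribuo, since `metadata_props` in the ONNX format is a plain list of string key/value pairs.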
-
exportONNXModel
ai.onnx.proto.OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion)
Exports this Model as an ONNX protobuf.
Parameters:
domain - A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
modelVersion - A version number for this model.
Returns:
The ONNX ModelProto representing this Tribuo Model.
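The `domain` argument is a reverse-DNS namespace, and the example given, org.tribuo.classification.sgd.linear, is the Java package of the classification LinearSGDModel. One possible convention, sketched below under that assumption, is to reuse the model class's package name as the domain. `OnnxDomains` and its methods are hypothetical, and the validity check is a loose heuristic rather than a rule ONNX imposes.

```java
/**
 * Hypothetical helper (not part of Tribuo): derives a reverse-DNS domain
 * string from a class's package name and applies a light sanity check.
 */
final class OnnxDomains {

    private OnnxDomains() {}

    /** Returns the class's package name, rejecting the default (empty) package. */
    static String fromPackage(Class<?> modelClass) {
        String pkg = modelClass.getPackageName();
        if (!looksLikeReverseDns(pkg)) {
            throw new IllegalArgumentException("Package '" + pkg + "' is not a usable domain");
        }
        return pkg;
    }

    /** True if the name has at least two dot-separated parts, each a Java identifier. */
    static boolean looksLikeReverseDns(String name) {
        String[] parts = name.split("\\.", -1);
        if (parts.length < 2) {
            return false;
        }
        for (String part : parts) {
            if (part.isEmpty() || !Character.isJavaIdentifierStart(part.charAt(0))) {
                return false;
            }
            for (int i = 1; i < part.length(); i++) {
                if (!Character.isJavaIdentifierPart(part.charAt(i))) {
                    return false;
                }
            }
        }
        return true;
    }
}
```

For example, `OnnxDomains.fromPackage(java.util.ArrayList.class)` returns "java.util", while a class in the default package is rejected.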
-
writeONNXGraph
Parameters:
input - The input to the model graph.
Returns:
The output node of the model graph.
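writeONNXGraph takes a reference to the graph input and returns the output node, so an implementation appends the nodes that compute the model between the two. The toy sketch below illustrates that contract for a linear model y = xW + b using ONNX's MatMul and Add operators. `ToyGraph`, `Node`, and `writeLinear` are hypothetical and do not reflect the signatures of Tribuo's ONNXRef or ONNXContext classes.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Toy, hypothetical illustration (not Tribuo's ONNX API) of the writeONNXGraph
 * contract: start from the graph input, append the nodes that compute the
 * model, and return a reference to the final output.
 */
final class ToyGraph {

    /** One ONNX-style node: operator type, input names, and output name. */
    static final class Node {
        final String opType;
        final List<String> inputs;
        final String output;

        Node(String opType, List<String> inputs, String output) {
            this.opType = opType;
            this.inputs = inputs;
            this.output = output;
        }
    }

    final List<Node> nodes = new ArrayList<>();

    /** Appends MatMul(input, weights) then Add(product, bias); returns the output name. */
    String writeLinear(String input, String weights, String bias) {
        String product = input + "_matmul";
        nodes.add(new Node("MatMul", List.of(input, weights), product));
        String output = product + "_add";
        nodes.add(new Node("Add", List.of(product, bias), output));
        return output;
    }
}
```

With x of shape [batch, numFeatures] and W of shape [numFeatures, numOutputDimensions], MatMul yields [batch, numOutputDimensions], and Add broadcasts the bias vector across the batch, matching the output shape stated at the top of this page. A real exporter could equally fuse the two steps into a single Gemm node.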
-
saveONNXModel
default void saveONNXModel(String domain, long modelVersion, Path outputPath) throws IOException
Exports this Model as an ONNX file.
Parameters:
domain - A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
modelVersion - A version number for this model.
outputPath - The path to write to.
Throws:
IOException - If the file could not be written to.
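saveONNXModel is a default method, which suggests it can be written entirely in terms of the abstract export method: produce the serialized model, then write it to disk. The hypothetical `BytesExportable` interface below sketches that pattern with a raw byte array standing in for the ModelProto. It illustrates the pattern under that assumption and is not Tribuo's source.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

/**
 * Hypothetical stand-in (not Tribuo's interface): an abstract export method
 * produces the serialized model, and a default method writes it to a path.
 */
interface BytesExportable {

    /** Stand-in for exportONNXModel: returns the serialized model bytes. */
    byte[] exportBytes(String domain, long modelVersion);

    /** Stand-in for saveONNXModel: any IOException from the write propagates to the caller. */
    default void saveBytes(String domain, long modelVersion, Path outputPath) throws IOException {
        Files.write(outputPath, exportBytes(domain, modelVersion));
    }
}
```

With protobuf-java, a ModelProto can be turned into bytes via `toByteArray()` or streamed with `writeTo(OutputStream)`, either of which would fit in place of `exportBytes`.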
-
serializeProvenance
default String serializeProvenance(ModelProvenance provenance)
Serializes the model provenance to a String.
Parameters:
provenance - The provenance to serialize.
Returns:
The serialized form of the ModelProvenance.
-