This page describes Tribuo 4.2.

Model export and deployment tutorial

Tribuo works best as a library that provides training and deployment inside the JVM where the application is running. However, sometimes you need to deploy models elsewhere, either in another programming environment like Python or in a cloud service. To support these use cases, many of Tribuo's models can be exported as ONNX models, a cross-platform model exchange format. ONNX is widely supported across industry, on edge devices, hardware accelerators, and cloud services. Tribuo also supports loading ONNX models and scoring them as native Tribuo models; for more information, see the external models tutorial.

This tutorial shows how to export models in ONNX format, how to recover the provenance information from Tribuo-exported ONNX models, and how to deploy an ONNX model in OCI Data Science (though other cloud providers support ONNX models too). We'll export a factorization machine, create an ensemble of that factorization machine along with some other models, and export the ensemble. Then we'll discuss how to interact with the provenance of an exported model, before concluding by deploying the model to OCI.

Setup

This tutorial requires ONNX Runtime to score the exported models, so by default it will only run on x86_64 platforms. ONNX Runtime can be compiled on ARM64 platforms, but that binary is not included in the Maven Central jar that Tribuo depends on, so you will need to compile ONNX Runtime from source to run the tutorial on ARM.
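If you want a notebook or application to fail fast on an unsupported platform, one option is a small check of the JVM's `os.arch` system property before touching ONNX Runtime. This is a minimal sketch, assuming the JVM reports `amd64` (typical on Linux and Windows) or `x86_64` (typical on Intel macOS) for x86_64 hosts; the `ArchCheck` class and `isX86_64` helper are hypothetical and not part of Tribuo or ONNX Runtime.

```java
import java.util.Locale;
import java.util.Set;

public class ArchCheck {
    // Assumption: these are the os.arch names JVMs commonly report for x86_64 hosts.
    private static final Set<String> X86_64_NAMES = Set.of("amd64", "x86_64");

    /** Hypothetical helper: true if the given os.arch value looks like x86_64. */
    static boolean isX86_64(String osArch) {
        // Guard against null first, because Set.of(...).contains(null) throws.
        return osArch != null && X86_64_NAMES.contains(osArch.toLowerCase(Locale.ROOT));
    }

    public static void main(String[] args) {
        String arch = System.getProperty("os.arch");
        if (isX86_64(arch)) {
            System.out.println("Running on " + arch + ", the Maven Central ONNX Runtime binary should load.");
        } else {
            System.out.println("Running on " + arch + ", you may need to build ONNX Runtime from source.");
        }
    }
}
```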

We're going to use MNIST as the example dataset for this tutorial, so you'll need to download it if you haven't already.

First the training set:

wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz

wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz

Then the test set:

wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz

wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

As usual, we'll load in some jars for classification problems, along with Tribuo's ONNX Runtime, OCI, and JSON interfaces.

In [1]:
%jars ./tribuo-classification-experiments-4.2.0-jar-with-dependencies.jar
%jars ./tribuo-oci-4.2.0-jar-with-dependencies.jar
%jars ./tribuo-onnx-4.2.0-jar-with-dependencies.jar
%jars ./tribuo-json-4.2.0-jar-with-dependencies.jar
In [2]:
import java.nio.file.Files;
import java.nio.file.Paths;

import org.tribuo.*;
import org.tribuo.classification.*;
import org.tribuo.classification.ensemble.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.fm.FMClassificationTrainer;
import org.tribuo.classification.sgd.linear.*;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.ensemble.*;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.datasource.IDXDataSource;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.interop.onnx.*;
import org.tribuo.math.optimisers.*;
import org.tribuo.interop.oci.*;
import org.tribuo.util.onnx.*;
import org.tribuo.util.Util;
import com.oracle.bmc.ConfigFileReader;
import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider;
import com.oracle.bmc.datascience.DataScienceClient;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.util.Pair;

import ai.onnxruntime.*;

Then we'll load in MNIST.

In [3]:
var labelFactory = new LabelFactory();
var labelEvaluator = new LabelEvaluator();
var mnistTrainSource = new IDXDataSource<>(Paths.get("train-images-idx3-ubyte.gz"),Paths.get("train-labels-idx1-ubyte.gz"),labelFactory);
var mnistTestSource = new IDXDataSource<>(Paths.get("t10k-images-idx3-ubyte.gz"),Paths.get("t10k-labels-idx1-ubyte.gz"),labelFactory);
var mnistTrain = new MutableDataset<>(mnistTrainSource);
var mnistTest = new MutableDataset<>(mnistTestSource);
System.out.println(String.format("MNIST train size = %d, number of features = %d, number of classes = %d",mnistTrain.size(),mnistTrain.getFeatureMap().size(),mnistTrain.getOutputInfo().size()));
System.out.println(String.format("MNIST test size = %d, number of features = %d, number of classes = %d",mnistTest.size(),mnistTest.getFeatureMap().size(),mnistTest.getOutputInfo().size()));
MNIST train size = 60000, number of features = 717, number of classes = 10
MNIST test size = 10000, number of features = 668, number of classes = 10

Exporting a single classification model

We're going to train a multi-class factorization machine, a non-linear model that approximates all pairwise feature interactions using a small per-feature embedding vector (the interaction weight for a pair of features is the dot product of their embeddings). It's similar to a logistic regression with an additional feature-feature interaction term, one per output label. In Tribuo, factorization machines are trained using stochastic gradient descent, with the same SGD algorithms Tribuo uses for other models. We're going to use AdaGrad, as it's usually a good baseline.
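To make the pairwise interaction term concrete, here is a minimal sketch of how a factorization machine could score a single output label, assuming the standard second-order formulation: a bias, a linear term, and a factorized pairwise term. It is not Tribuo's FMClassificationModel code, and the class and method names are hypothetical. It uses the well-known identity that the sum over feature pairs can be computed in O(k·n) time instead of O(k·n²), and compares it with the brute-force sum.

```java
public class FMScoreSketch {
    /** Brute-force pairwise term: sum over i < j of <v_i, v_j> * x_i * x_j. */
    static double pairwiseBruteForce(double[] x, double[][] v) {
        double total = 0.0;
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                total += dot * x[i] * x[j];
            }
        }
        return total;
    }

    /** Same term via the identity 0.5 * sum_f [(sum_i v_if x_i)^2 - sum_i (v_if x_i)^2]. */
    static double pairwiseFast(double[] x, double[][] v) {
        int k = v[0].length; // factor size
        double total = 0.0;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;
            double sumOfSquares = 0.0;
            for (int i = 0; i < x.length; i++) {
                double term = v[i][f] * x[i];
                sum += term;
                sumOfSquares += term * term;
            }
            total += sum * sum - sumOfSquares;
        }
        return 0.5 * total;
    }

    /** Full score for one output label: bias + linear term + factorized pairwise term. */
    static double score(double bias, double[] w, double[][] v, double[] x) {
        double linear = 0.0;
        for (int i = 0; i < x.length; i++) {
            linear += w[i] * x[i];
        }
        return bias + linear + pairwiseFast(x, v);
    }

    public static void main(String[] args) {
        double[] x = {1.0, 0.5, 0.0, 2.0};                                  // one example, 4 features
        double[] w = {0.1, -0.2, 0.3, 0.05};                                // linear weights for one label
        double[][] v = {{0.1, 0.2}, {0.3, -0.1}, {0.0, 0.4}, {-0.2, 0.1}}; // embeddings, factor size 2
        System.out.println("brute force = " + pairwiseBruteForce(x, v));
        System.out.println("fast        = " + pairwiseFast(x, v));
        System.out.println("score       = " + score(0.01, w, v, x));
    }
}
```

In a multi-class model there would be one set of bias, linear weights, and embeddings per label, with the per-label scores then normalised (for example with a softmax) into a distribution.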

In [4]:
var fmLabelTrainer = new FMClassificationTrainer(new LogMulticlass(),  // Loss function
                                                 new AdaGrad(0.1,0.1), // Gradient optimiser
                                                 5,                    // Number of training epochs
                                                 30000,                // Logging interval
                                                 Trainer.DEFAULT_SEED, // RNG seed
                                                 6,                    // Factor size
                                                 0.1                   // Factor initialisation variance
                                                 );

After defining the trainer, we train the model as usual. Factorization machines take a little longer to train than logistic regression, but not excessively so.

In [5]:
var fmStartTime = System.currentTimeMillis();
var fmMNIST = fmLabelTrainer.train(mnistTrain);
var fmEndTime = System.currentTimeMillis();
System.out.println("Training factorization machine took " + Util.formatDuration(fmStartTime,fmEndTime));
Training factorization machine took (00:00:11:305)

And then we evaluate it using Tribuo's built-in evaluation system.

In [6]:
fmStartTime = System.currentTimeMillis();
var mnistFMEval = labelEvaluator.evaluate(fmMNIST,mnistTest);
fmEndTime = System.currentTimeMillis();
System.out.println("Scoring factorization machine took " + Util.formatDuration(fmStartTime,fmEndTime));
System.out.println(mnistFMEval.toString());
System.out.println(mnistFMEval.getConfusionMatrix().toString());
Scoring factorization machine took (00:00:00:412)
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         959          21          31       0.979       0.969       0.974
1                           1,135       1,120          15          22       0.987       0.981       0.984
2                           1,032         976          56          57       0.946       0.945       0.945
3                           1,010         952          58          39       0.943       0.961       0.952
4                             982         952          30          49       0.969       0.951       0.960
5                             892         857          35          63       0.961       0.932       0.946
6                             958         920          38          30       0.960       0.968       0.964
7                           1,028         969          59          36       0.943       0.964       0.953
8                             974         916          58          57       0.940       0.941       0.941
9                           1,009         951          58          44       0.943       0.956       0.949
Total                      10,000       9,572         428         428
Accuracy                                                                    0.957
Micro Average                                                               0.957       0.957       0.957
Macro Average                                                               0.957       0.957       0.957
Balanced Error Rate                                                         0.043
               0       1       2       3       4       5       6       7       8       9
0            959       0       0       0       1       2       7       4       4       3
1              0   1,120       4       1       3       0       3       0       4       0
2              6       5     976       7       7       2       5       8      14       2
3              0       2      15     952       0      19       1       3      14       4
4              3       3       7       1     952       0       4       1       1      10
5              3       1       0       6       1     857       5       5      13       1
6              8       2       7       2       7      11     920       1       0       0
7              2       5      13       5       4       4       0     969       4      22
8              2       1       9       9      11      15       4       5     916       2
9              7       3       2       8      15      10       1       9       3     951

We get about 96% accuracy (95.7%) on MNIST, which is pretty good for a fairly simple model. Now let's export it to ONNX, then load it back in via Tribuo's ONNX Runtime interface and compare the performance. We'll use this model in the reproducibility tutorial, so we'll save it to disk in the tutorials folder.

Tribuo models that support ONNX export implement the ONNXExportable interface, which defines methods for constructing an ONNX protobuf and saving it to disk.

In [7]:
var fmMNISTPath = Paths.get(".","fm-mnist.onnx");
fmMNIST.saveONNXModel("org.tribuo.tutorials.onnxexport.fm", // namespace for the model
                      0,                                    // model version number
                      fmMNISTPath                           // path to save the model
                      );

To load an ONNX model, we need to define the mapping between Tribuo's feature names and the indices that the ONNX model understands. Fortunately, for models exported from Tribuo we already have that information, as it is stored in the model's feature and output maps. We'll extract it into the general form that ONNXExternalModel expects.
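Conceptually, the feature map is what lets a sparse, name-keyed example become the fixed-length numeric vector an ONNX model consumes: each named feature is written into the slot given by its ID and every other slot stays zero. The sketch below illustrates that idea with plain Java collections. It assumes feature IDs run contiguously from 0 to size-1, it is not Tribuo's DenseTransformer code, and both `densify` and the `pixel-N` feature names are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

public class DensifySketch {
    /**
     * Hypothetical helper: write each (featureName -> value) pair into the slot given by
     * featureIDs, leaving all other slots at zero. In this sketch, features missing from
     * the map are ignored because they have no column in the model's input.
     */
    static float[] densify(Map<String, Double> example, Map<String, Integer> featureIDs) {
        float[] dense = new float[featureIDs.size()];
        for (Map.Entry<String, Double> e : example.entrySet()) {
            Integer id = featureIDs.get(e.getKey());
            if (id != null) {
                dense[id] = e.getValue().floatValue(); // single precision, as ONNX Runtime uses
            }
        }
        return dense;
    }

    public static void main(String[] args) {
        Map<String, Integer> featureIDs = new HashMap<>();
        featureIDs.put("pixel-0", 0);
        featureIDs.put("pixel-1", 1);
        featureIDs.put("pixel-2", 2);

        Map<String, Double> example = new HashMap<>();
        example.put("pixel-2", 0.75);
        example.put("pixel-99", 1.0); // not in the feature map, so it is dropped

        System.out.println(java.util.Arrays.toString(densify(example, featureIDs)));
    }
}
```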

In [8]:
Map<String, Integer> mnistFeatureMap = new HashMap<>();
for (VariableInfo f : fmMNIST.getFeatureIDMap()){
    VariableIDInfo id = (VariableIDInfo) f;
    mnistFeatureMap.put(id.getName(),id.getID());
}
Map<Label, Integer> mnistOutputMap = new HashMap<>();
for (Pair<Integer,Label> l : fmMNIST.getOutputIDInfo()) {
    mnistOutputMap.put(l.getB(), l.getA());
}

Now we'll define a test function that compares two sets of predictions. ONNX Runtime uses single precision for its computations while Tribuo uses double precision, so the prediction scores won't be bitwise equal and must instead be compared within a tolerance.
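A quick way to see why bitwise comparison fails is to round a double through float, which is roughly what happens when the same value is computed in single precision. This stdlib sketch shows that the round trip changes the value, while the error stays far below the 1e-5 tolerance this tutorial uses when comparing predictions. The probability value is arbitrary and `withinTolerance` is a hypothetical helper.

```java
public class PrecisionSketch {
    /** Hypothetical helper: compare two scores with an absolute tolerance. */
    static boolean withinTolerance(double a, double b, double delta) {
        return Math.abs(a - b) <= delta;
    }

    public static void main(String[] args) {
        double tribuoScore = 0.123456789012345; // an arbitrary double-precision probability
        double onnxScore = (float) tribuoScore; // the same value after rounding to single precision
        System.out.println("bitwise equal:  " + (tribuoScore == onnxScore));
        System.out.println("absolute error: " + Math.abs(tribuoScore - onnxScore));
        System.out.println("within 1e-5:    " + withinTolerance(tribuoScore, onnxScore, 1e-5));
    }
}
```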

In [9]:
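public boolean checkPredictions(List<Prediction<Label>> nativePredictions, List<Prediction<Label>> onnxPredictions, double delta) {
    // Both models must have produced one prediction per test example
    if (nativePredictions.size() != onnxPredictions.size()) {
        System.out.println("Prediction lists have different sizes - "
                + nativePredictions.size() + " and " + onnxPredictions.size());
        return false;
    }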
    for (int i = 0; i < nativePredictions.size(); i++) {
        Prediction<Label> tribuo = nativePredictions.get(i);
        Prediction<Label> external = onnxPredictions.get(i);
        // Check the predicted label
        if (!tribuo.getOutput().getLabel().equals(external.getOutput().getLabel())) {
            System.out.println("At index " + i + " predictions are not equal - "
                    + tribuo.getOutput().getLabel() + " and "
                    + external.getOutput().getLabel());
            return false;
        }
        // Check the maximum score
        if (Math.abs(tribuo.getOutput().getScore() - external.getOutput().getScore()) > delta) {
            System.out.println("At index " + i + " predictions are not equal - "
                    + tribuo.getOutput() + " and "
                    + external.getOutput());
            return false;
        }
        // Check the score distribution
        for (Map.Entry<String, Label> l : tribuo.getOutputScores().entrySet()) {
            Label other = external.getOutputScores().get(l.getKey());
            if (other == null) {
                System.out.println("At index " + i + " failed to find label " + l.getKey() + " in ORT prediction.");
                return false;
            } else {
                if (Math.abs(l.getValue().getScore() - other.getScore()) > delta) {
                    System.out.println("At index " + i + " predictions are not equal - "
                            + tribuo.getOutputScores() + " and "
                            + external.getOutputScores());
                    return false;
                }
            }
        }
    }
    return true;
}

Then we'll construct the ONNXExternalModel, loading our freshly created ONNX model with the feature and output mappings we built earlier. First we create a SessionOptions object, which controls model inference. By default it uses a single thread on one CPU, but by setting values in the options object before building the external model we can make it run on multiple threads, or use GPUs and other accelerator hardware supported by ONNX Runtime.

In [10]:
var ortEnv = OrtEnvironment.getEnvironment();
var sessionOpts = new OrtSession.SessionOptions();
var denseTransformer = new DenseTransformer();
var labelTransformer = new LabelTransformer();
ONNXExternalModel<Label> onnxFM = ONNXExternalModel.createOnnxModel(labelFactory, mnistFeatureMap, mnistOutputMap,
                    denseTransformer, labelTransformer, sessionOpts, fmMNISTPath, "input");

An ONNXExternalModel is a Tribuo model, so we can use the same evaluation infrastructure.

In [11]:
var onnxStartTime = System.currentTimeMillis();
var mnistONNXEval = labelEvaluator.evaluate(onnxFM,mnistTest);
var onnxEndTime = System.currentTimeMillis();
System.out.println("Scoring ONNX factorization machine took " + Util.formatDuration(onnxStartTime,onnxEndTime));
System.out.println(mnistONNXEval.toString());
System.out.println(mnistONNXEval.getConfusionMatrix().toString());
Scoring ONNX factorization machine took (00:00:00:810)
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         959          21          31       0.979       0.969       0.974
1                           1,135       1,120          15          22       0.987       0.981       0.984
2                           1,032         976          56          57       0.946       0.945       0.945
3                           1,010         952          58          39       0.943       0.961       0.952
4                             982         952          30          49       0.969       0.951       0.960
5                             892         857          35          63       0.961       0.932       0.946
6                             958         920          38          30       0.960       0.968       0.964
7                           1,028         969          59          36       0.943       0.964       0.953
8                             974         916          58          57       0.940       0.941       0.941
9                           1,009         951          58          44       0.943       0.956       0.949
Total                      10,000       9,572         428         428
Accuracy                                                                    0.957
Micro Average                                                               0.957       0.957       0.957
Macro Average                                                               0.957       0.957       0.957
Balanced Error Rate                                                         0.043
               0       1       2       3       4       5       6       7       8       9
0            959       0       0       0       1       2       7       4       4       3
1              0   1,120       4       1       3       0       3       0       4       0
2              6       5     976       7       7       2       5       8      14       2
3              0       2      15     952       0      19       1       3      14       4
4              3       3       7       1     952       0       4       1       1      10
5              3       1       0       6       1     857       5       5      13       1
6              8       2       7       2       7      11     920       1       0       0
7              2       5      13       5       4       4       0     969       4      22
8              2       1       9       9      11      15       4       5     916       2
9              7       3       2       8      15      10       1       9       3     951

The two models produce identical evaluations, but they could still be producing slightly different probability values, so let's check using our more precise comparison function. checkPredictions logs any divergence it finds, and returns false if the predictions differ (true otherwise). We're going to use a delta of 1e-5, and consider differences below that threshold to be irrelevant.

In [12]:
System.out.println("Predictions are equal - " + 
                    checkPredictions(mnistFMEval.getPredictions(), mnistONNXEval.getPredictions(), 1e-5));
Predictions are equal - true

An important part of a Tribuo model is its provenance. We don't want to lose that information when exporting models to ONNX format, so we encode the provenance in the ONNX protobuf. It is stored using OLCUT's marshalled provenance format, and the protobuf definitions are available in OLCUT, so other systems can parse them too. As a result, when loading a Tribuo-exported ONNX model, the ONNXExternalModel has two provenance objects: one for the ONNXExternalModel itself, and one for the original Model object.

Let's examine both of these provenances. First the one for the ONNXExternalModel:

In [13]:
System.out.println("ONNXExternalModel provenance:\n" + ProvenanceUtil.formattedProvenanceString(onnxFM.getProvenance()));
ONNXExternalModel provenance:
ONNXExternalModel(
	class-name = org.tribuo.interop.onnx.ONNXExternalModel
	dataset = Dataset(
			class-name = org.tribuo.Dataset
			datasource = DataSource(
					description = unknown-external-data
					outputFactory = LabelFactory(
							class-name = org.tribuo.classification.LabelFactory
						)
					datasource-creation-time = 2021-12-18T20:36:37.266127-05:00
				)
			transformations = List[]
			is-sequence = false
			is-dense = false
			num-examples = -1
			num-features = 717
			num-outputs = 10
			tribuo-version = 4.2.0
		)
	trainer = Trainer(
			class-name = org.tribuo.Trainer
			fileModifiedTime = 2021-12-18T20:36:36.445-05:00
			modelHash = 06071247AEDE7539B899A2D530508D8E2B43304B8A7884A257368AA2CF1C18ED
			location = file:/Users/apocock/Development/Tribuo/tutorials/./fm-mnist.onnx
		)
	trained-at = 2021-12-18T20:36:37.263832-05:00
	instance-values = Map{
		model-domain=org.tribuo.tutorials.onnxexport.fm
		model-graphname=FMClassificationModel
		model-description=factorization-machine-model - Model(class-name=org.tribuo.classification.sgd.fm.FMClassificationModel,dataset=Dataset(class-name=org.tribuo.MutableDataset,datasource=DataSource(class-name=org.tribuo.datasource.IDXDataSource,outputPath=/Users/apocock/Development/Tribuo/tutorials/train-labels-idx1-ubyte.gz,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),featuresPath=/Users/apocock/Development/Tribuo/tutorials/train-images-idx3-ubyte.gz,features-file-modified-time=2000-07-21T14:20:24-04:00,output-resource-hash=SHA-256[3552534A0A558BBED6AED32B30C495CCA23D567EC52CAC8BE1A0730E8010255C],datasource-creation-time=2021-12-18T20:36:23.109293-05:00,output-file-modified-time=2000-07-21T14:20:27-04:00,idx-feature-type=UBYTE,features-resource-hash=SHA-256[440FCABF73CC546FA21475E81EA370265605F56BE210A4024D2CA8F203523609],host-short-name=DataSource),transformations=[],is-sequence=false,is-dense=false,num-examples=60000,num-features=717,num-outputs=10,tribuo-version=4.2.0),trainer=Trainer(class-name=org.tribuo.classification.sgd.fm.FMClassificationTrainer,seed=12345,variance=0.1,minibatchSize=1,factorizedDimSize=6,shuffle=true,epochs=5,optimiser=StochasticGradientOptimiser(class-name=org.tribuo.math.optimisers.AdaGrad,epsilon=0.1,initialLearningRate=0.1,initialValue=0.0,host-short-name=StochasticGradientOptimiser),loggingInterval=30000,objective=LabelObjective(class-name=org.tribuo.classification.sgd.objectives.LogMulticlass,host-short-name=LabelObjective),tribuo-version=4.2.0,train-invocation-count=0,is-sequence=false,host-short-name=Trainer),trained-at=2021-12-18T20:36:35.640663-05:00,instance-values={},tribuo-version=4.2.0,java-version=17.0.1,os-name=Mac OS X,os-arch=x86_64)
		model-producer=Tribuo
		model-version=0
		input-name=input
	}
	tribuo-version = 4.2.0
	java-version = 17.0.1
	os-name = Mac OS X
	os-arch = x86_64
)

This has the location the ONNX file was loaded from, a hash of the file, and timestamps for both the ONNX file and the model object wrapping it.
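The `modelHash` value above is 64 hexadecimal characters, which is consistent with a SHA-256 digest of the ONNX file. Assuming that is what it records, you could re-check a deployed file against its provenance with `java.security.MessageDigest`. The sketch below hashes a file's bytes and formats the digest as upper-case hex to match the style shown above; the `sha256Hex` helper is hypothetical, and whether Tribuo's hash covers exactly the raw file bytes is an assumption to verify before relying on it.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

public class FileHashSketch {
    /** Hypothetical helper: SHA-256 of a file's bytes, formatted as upper-case hex. */
    static String sha256Hex(Path file) throws IOException, NoSuchAlgorithmException {
        MessageDigest digest = MessageDigest.getInstance("SHA-256");
        byte[] hash = digest.digest(Files.readAllBytes(file));
        StringBuilder sb = new StringBuilder(hash.length * 2);
        for (byte b : hash) {
            sb.append(String.format("%02X", b)); // a negative byte formats as its unsigned value
        }
        return sb.toString();
    }

    public static void main(String[] args) throws Exception {
        // Assumes fm-mnist.onnx was written to the working directory earlier in the tutorial.
        Path onnxFile = Path.of("fm-mnist.onnx");
        if (Files.exists(onnxFile)) {
            System.out.println("SHA-256 = " + sha256Hex(onnxFile));
        } else {
            System.out.println(onnxFile + " not found");
        }
    }
}
```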

Now let's look at the original Model provenance:

In [14]:
System.out.println("ONNX file provenance:\n" + ProvenanceUtil.formattedProvenanceString(onnxFM.getTribuoProvenance().get()));
ONNX file provenance:
FMClassificationModel(
	class-name = org.tribuo.classification.sgd.fm.FMClassificationModel
	dataset = MutableDataset(
			class-name = org.tribuo.MutableDataset
			datasource = IDXDataSource(
					class-name = org.tribuo.datasource.IDXDataSource
					outputFactory = LabelFactory(
							class-name = org.tribuo.classification.LabelFactory
						)
					outputPath = /Users/apocock/Development/Tribuo/tutorials/train-labels-idx1-ubyte.gz
					featuresPath = /Users/apocock/Development/Tribuo/tutorials/train-images-idx3-ubyte.gz
					features-file-modified-time = 2000-07-21T14:20:24-04:00
					output-resource-hash = 3552534A0A558BBED6AED32B30C495CCA23D567EC52CAC8BE1A0730E8010255C
					datasource-creation-time = 2021-12-18T20:36:23.109293-05:00
					output-file-modified-time = 2000-07-21T14:20:27-04:00
					idx-feature-type = UBYTE
					features-resource-hash = 440FCABF73CC546FA21475E81EA370265605F56BE210A4024D2CA8F203523609
					host-short-name = DataSource
				)
			transformations = List[]
			is-sequence = false
			is-dense = false
			num-examples = 60000
			num-features = 717
			num-outputs = 10
			tribuo-version = 4.2.0
		)
	trainer = FMClassificationTrainer(
			class-name = org.tribuo.classification.sgd.fm.FMClassificationTrainer
			seed = 12345
			variance = 0.1
			minibatchSize = 1
			factorizedDimSize = 6
			shuffle = true
			epochs = 5
			optimiser = AdaGrad(
					class-name = org.tribuo.math.optimisers.AdaGrad
					epsilon = 0.1
					initialLearningRate = 0.1
					initialValue = 0.0
					host-short-name = StochasticGradientOptimiser
				)
			loggingInterval = 30000
			objective = LogMulticlass(
					class-name = org.tribuo.classification.sgd.objectives.LogMulticlass
					host-short-name = LabelObjective
				)
			tribuo-version = 4.2.0
			train-invocation-count = 0
			is-sequence = false
			host-short-name = Trainer
		)
	trained-at = 2021-12-18T20:36:35.640663-05:00
	instance-values = Map{}
	tribuo-version = 4.2.0
	java-version = 17.0.1
	os-name = Mac OS X
	os-arch = x86_64
)

We can also check that the provenance extracted from the ONNX file is the same as the provenance in the original model object.

In [15]:
var equality = fmMNIST.getProvenance().equals(onnxFM.getTribuoProvenance().get()) ? "equal" : "not equal";
System.out.println("Provenances are " + equality);
Provenances are equal

Exporting an ensemble

Tribuo allows the creation of arbitrary ensembles, and these are usually powerful models which are useful to deploy. So we're going to make a three-element voting ensemble out of our factorization machine and two other models, and export that to ONNX as well. The other models are a logistic regression and a smaller factorization machine, but we could use any classification model supported by Tribuo, including another ensemble. As this is a small ensemble of similar models, our goal is to demonstrate the functionality rather than to improve performance on MNIST.

In [16]:
var lrTrainer = new LogisticRegressionTrainer();
var smallFMTrainer = new FMClassificationTrainer(new LogMulticlass(),  // Loss function
                                                 new AdaGrad(0.1,0.1), // Gradient optimiser
                                                 2,                    // Number of training epochs
                                                 30000,                // Logging interval
                                                 42L,                  // RNG seed
                                                 3,                    // Factor size
                                                 0.1                   // Factor initialisation variance
                                                 );
var lrModel = lrTrainer.train(mnistTrain);
var smallFMModel = smallFMTrainer.train(mnistTrain);

Tribuo's WeightedEnsembleModel class allows the creation of arbitrary ensembles, with or without voting weights. We're going to create an unweighted ensemble of our three models using the standard VotingCombiner, which takes a majority vote over the labels predicted by the three models, with ties broken by the first label.
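The majority-vote rule itself is simple to sketch. The code below counts one vote per ensemble member for its predicted label index, reports each label's share of the votes, and breaks ties in favour of the lowest label index. Reading "the first label" as the lowest index is an assumption made for this illustration; it is not Tribuo's VotingCombiner source, and the class and method names are hypothetical.

```java
import java.util.Arrays;

public class VotingSketch {
    /** Hypothetical helper: each element of votes is one member's predicted label index. */
    static double[] voteShares(int[] votes, int numLabels) {
        double[] shares = new double[numLabels];
        for (int v : votes) {
            shares[v] += 1.0;
        }
        for (int i = 0; i < numLabels; i++) {
            shares[i] /= votes.length; // fraction of members that voted for label i
        }
        return shares;
    }

    /** Majority vote; the strict '>' keeps the earliest index when shares are tied. */
    static int majority(int[] votes, int numLabels) {
        double[] shares = voteShares(votes, numLabels);
        int best = 0;
        for (int i = 1; i < numLabels; i++) {
            if (shares[i] > shares[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[] twoAgree = {7, 7, 9};    // two members predict "7", one predicts "9"
        int[] allDisagree = {3, 5, 8}; // a three-way tie
        System.out.println(majority(twoAgree, 10) + " " + Arrays.toString(voteShares(twoAgree, 10)));
        System.out.println(majority(allDisagree, 10));
    }
}
```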

In [17]:
var ensemble = WeightedEnsembleModel.createEnsembleFromExistingModels("ensemble", // Model name
                                           List.of(fmMNIST,lrModel,smallFMModel), // Ensemble members
                                           new VotingCombiner());                 // Combination operator
In [18]:
var ensembleStartTime = System.currentTimeMillis();
var ensembleEval = labelEvaluator.evaluate(ensemble,mnistTest);
var ensembleEndTime = System.currentTimeMillis();
System.out.println("Scoring ensemble took " + Util.formatDuration(ensembleStartTime,ensembleEndTime));
System.out.println(ensembleEval.toString());
System.out.println(ensembleEval.getConfusionMatrix().toString());
Scoring ensemble took (00:00:00:675)
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         965          15          43       0.985       0.957       0.971
1                           1,135       1,119          16          34       0.986       0.971       0.978
2                           1,032         979          53          86       0.949       0.919       0.934
3                           1,010         926          84          38       0.917       0.961       0.938
4                             982         937          45          49       0.954       0.950       0.952
5                             892         837          55          49       0.938       0.945       0.942
6                             958         922          36          32       0.962       0.966       0.964
7                           1,028         978          50          52       0.951       0.950       0.950
8                             974         918          56          98       0.943       0.904       0.923
9                           1,009         917          92          21       0.909       0.978       0.942
Total                      10,000       9,498         502         502
Accuracy                                                                    0.950
Micro Average                                                               0.950       0.950       0.950
Macro Average                                                               0.949       0.950       0.949
Balanced Error Rate                                                         0.051
               0       1       2       3       4       5       6       7       8       9
0            965       0       0       1       0       2       7       3       2       0
1              0   1,119       5       0       0       0       5       1       5       0
2              7       5     979       4       5       1       3       7      20       1
3              3       3      29     926       1      14       0       8      25       1
4              3       2      11       1     937       0       3       1      11      13
5              8       1       2       9       3     837      10       5      17       0
6              8       2       5       3       2      14     922       0       2       0
7              2       9      21       3       6       1       0     978       2       6
8              5       4      10       7      10       9       2       9     918       0
9              7       8       3      10      22       8       2      18      14     917

As before, we use the saveONNXModel method from the ONNXExportable interface to write out the model. Note that if one of the ensemble members isn't ONNXExportable, this call will throw a runtime exception.

In [19]:
var ensemblePath = Paths.get(".","ensemble-mnist.onnx");
ensemble.saveONNXModel("org.tribuo.tutorials.onnxexport.ensemble", // namespace for the model
                      0,                                           // model version number
                      ensemblePath                                 // path to save the model
                      );

We can load this model into ONNXExternalModel as well:

In [20]:
var onnxEnsemble = ONNXExternalModel.createOnnxModel(labelFactory, mnistFeatureMap, mnistOutputMap,
                    denseTransformer, labelTransformer, sessionOpts, ensemblePath, "input");
onnxStartTime = System.currentTimeMillis();
var mnistONNXEnsembleEval = labelEvaluator.evaluate(onnxEnsemble,mnistTest);
onnxEndTime = System.currentTimeMillis();
System.out.println("Scoring ONNX ensemble took " + Util.formatDuration(onnxStartTime,onnxEndTime));
System.out.println("Predictions are equal - " + 
                    checkPredictions(ensembleEval.getPredictions(), mnistONNXEnsembleEval.getPredictions(), 1e-5));
Scoring ONNX ensemble took (00:00:01:021)
Predictions are equal - true

Deploying the model

This portion of the tutorial describes how to deploy the ONNX model using OCI Data Science's model deployment service. ONNX models can also be deployed in many other environments, including other cloud providers' machine learning services, functions-as-a-service offerings running something like ONNX Runtime, and Oracle Machine Learning Services.

Tribuo's OCI Data Science support comes in two parts: a set of static methods for deploying models on the cloud, and the OCIModel class, which wraps a model endpoint so it can be used as a normal Tribuo model. Under the covers, we're going to use an OCI DS conda environment that contains the Python version of ONNX Runtime, and use it to make predictions from our model trained in Java.

To run this part of the tutorial, you'll need to have configured your access to OCI Data Science (if you've not done this before, see the OCI Data Science documentation for how to do it) and set up authentication to allow CLI access to OCI. You'll also need the compartment and project IDs for the OCI Data Science project you want to deploy into.

In [21]:
// Set these variables appropriately for your OCI account
var compartmentID = "your-oci-compartment-id";
var projectID = "your-oci-ds-project-id";

Now we'll instantiate the DS client and build the config object, which captures all the information about the model we're uploading. The models run inside a conda environment, and you need to select one that contains ONNX Runtime 1.6.0 or newer (Tribuo emits ONNX models using opset 13, which is supported in ONNX Runtime 1.6+). This can either be a custom environment you've created, or one provided by OCI Data Science.
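If you build a custom conda environment, a numeric comparison of the installed ONNX Runtime version string against 1.6.0 catches environments that are too old for opset 13. This is a minimal sketch: `isAtLeast` is a hypothetical helper, it assumes plain dotted version strings, and it strips suffixes such as `-rc1` before comparing. Comparing component by component matters because a plain string comparison would wrongly rank "1.10.0" below "1.6.0".

```java
public class VersionCheckSketch {
    /** Hypothetical helper: true if version >= minimum, comparing dotted numeric components. */
    static boolean isAtLeast(String version, String minimum) {
        int[] v = parse(version);
        int[] m = parse(minimum);
        for (int i = 0; i < Math.max(v.length, m.length); i++) {
            int a = i < v.length ? v[i] : 0; // missing components count as zero, so "1.6" == "1.6.0"
            int b = i < m.length ? m[i] : 0;
            if (a != b) {
                return a > b;
            }
        }
        return true;
    }

    private static int[] parse(String version) {
        String core = version.split("[-+]", 2)[0]; // drop suffixes like "-rc1" or "+cpu"
        String[] parts = core.split("\\.");
        int[] nums = new int[parts.length];
        for (int i = 0; i < parts.length; i++) {
            nums[i] = Integer.parseInt(parts[i]);
        }
        return nums;
    }

    public static void main(String[] args) {
        for (String installed : new String[]{"1.5.2", "1.6.0", "1.10.0"}) {
            System.out.println(installed + " supports opset 13: " + isAtLeast(installed, "1.6.0"));
        }
    }
}
```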

In [29]:
// Instantiate the client
var provider = new ConfigFileAuthenticationDetailsProvider(ConfigFileReader.parseDefault());
var dsClient = new DataScienceClient(provider);

// Instantiate an ObjectMapper for parsing the REST calls
var objMapper = OCIUtil.createObjectMapper();

// Select the conda environment
var condaName = "dataexpl_p37_cpu_v3"; // Also referred to as the "slug" in the OCI DS docs
var condaPath = "oci://service-conda-packs@id19sfcrra6z/service_pack/cpu/Data Exploration and Manipulation for CPU Python 3.7/3.0/dataexpl_p37_cpu_v3";

// Instantiate the model configuration
var dsConfig = new OCIUtil.OCIDSConfig(compartmentID,projectID);
var modelConfig = new OCIUtil.OCIModelArtifactConfig(dsConfig,          // Data Science config
                                             "tribuo-tutorial-model",   // Model name
                                             "A factorization machine", // Model description
                                             "org.tribuo.tutorial.test",// ONNX model domain
                                             0,                         // ONNX model version
                                             condaName,                 // Conda environment name
                                             condaPath);                // Conda environment path on object storage

We can now upload the model into OCI Data Science. The createModel method has an overload that accepts an ONNX file on disk, or you can pass in any model that implements ONNXExportable, as we do here. Tribuo sets the model metadata using the information it can extract from the Model object, and automatically generates the necessary Python script and YAML file that control the model's environment in the deployment. Note that models are distinct from model deployments, so a single model artifact can be deployed multiple times with different endpoints, VM sizes and scaling parameters.

var modelID = OCIUtil.createModel(fmMNIST,dsClient,objMapper,modelConfig);

The modelID is the reference for the model artifact stored in Oracle Cloud, and we'll need this to create a deployment wrapping the model.

To specify the model deployment configuration there's an OCIModelDeploymentConfig wrapper class. It contains the model ID, the model deployment name, the VM shape, the number of VM instances to create, and the bandwidth available for that model. At the time of writing, OCI DS supports the VM.Standard2 shapes.

var deployConfig = new OCIUtil.OCIModelDeploymentConfig(dsConfig,modelID,"tribuo-tutorial-deployment","VM.Standard2.1",10,1);

var deployURL = OCIUtil.deploy(deployConfig,dsClient,objMapper);
System.out.println(deployURL);

Model deployments take a few minutes to become active, so if you're following along you'll need to wait before running the next cells. You can check the deployment's progress in the OCI console for the data science project you're using.
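Rather than watching the console, a script could poll until the deployment is ready. The sketch below is a generic poll-with-exponential-backoff loop using only the JDK, with the status check passed in as a `BooleanSupplier`. Wiring it to OCI is left as an assumption: one option would be to fetch the model deployment through the OCI SDK's `DataScienceClient` and test whether its lifecycle state is active, but check the SDK documentation for the exact request and response types. `DeploymentPoller` is a hypothetical name.

```java
import java.time.Duration;
import java.util.function.BooleanSupplier;

/** Hypothetical helper that waits for a deployment to become ready. */
class DeploymentPoller {
    /**
     * Calls isReady until it returns true or maxAttempts checks have been made, sleeping between
     * checks and doubling the delay each time up to maxDelay. Returns false if the attempts run
     * out or the thread is interrupted.
     */
    static boolean waitUntilReady(BooleanSupplier isReady, Duration initialDelay,
                                  Duration maxDelay, int maxAttempts) {
        Duration delay = initialDelay;
        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            if (isReady.getAsBoolean()) {
                return true;
            }
            if (attempt == maxAttempts) {
                break; // no point sleeping after the final check
            }
            try {
                Thread.sleep(delay.toMillis());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // preserve the interrupt status for the caller
                return false;
            }
            Duration doubled = delay.multipliedBy(2);
            delay = doubled.compareTo(maxDelay) > 0 ? maxDelay : doubled;
        }
        return false;
    }

    public static void main(String[] args) {
        // Stand-in status check that reports "ready" on its third call; a real one would query OCI.
        int[] calls = {0};
        boolean ready = waitUntilReady(() -> ++calls[0] >= 3, Duration.ofMillis(10), Duration.ofMillis(40), 5);
        System.out.println("ready=" + ready + " after " + calls[0] + " checks"); // ready=true after 3 checks
    }
}
```

For a deployment that takes a few minutes, delays of tens of seconds are more sensible than the milliseconds used in the demo.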

Once the deployment has finished, we can wrap it in an OCIModel and check that it produces the same predictions as the factorization machine we deployed. OCIModel is a subclass of ExternalModel, just like the class used to load externally trained ONNX models, so we need to supply the mapping between Tribuo's feature domain and the feature indices expected by the model, the output domain mapping, and an OCIOutputConverter instance that converts the prediction matrix into Tribuo's Prediction objects. As we've deployed a factorization machine for MNIST we'll use OCILabelConverter, and the mappings are the same as the ones we used for the ONNX model earlier.
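Conceptually, the two mappings and the converter do the following: named, sparse Tribuo features are packed into the dense input vector the model expects, and the highest-scoring column of each output row is mapped back to a label. This is a hypothetical, simplified sketch using plain maps and arrays, not Tribuo's ExternalModel or OCILabelConverter implementation; the feature names `pixel-0` and so on are made up for the example.

```java
import java.util.Arrays;
import java.util.Map;

/** Hypothetical illustration of the feature mapping and output mapping an external model needs. */
class MappingSketch {
    /** Packs named sparse features into the dense vector layout the deployed model expects. */
    static float[] toDense(Map<String, Double> features, Map<String, Integer> featureMap, int numInputs) {
        float[] dense = new float[numInputs]; // features absent from the example stay 0
        for (Map.Entry<String, Double> e : features.entrySet()) {
            Integer index = featureMap.get(e.getKey());
            if (index != null) { // features the model was never trained on are dropped
                dense[index] = e.getValue().floatValue();
            }
        }
        return dense;
    }

    /** Picks the highest-scoring output column and maps it back to its label name. */
    static String toLabel(double[] scores, Map<String, Integer> outputMap) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        for (Map.Entry<String, Integer> e : outputMap.entrySet()) {
            if (e.getValue() == best) {
                return e.getKey();
            }
        }
        throw new IllegalArgumentException("No label mapped to column " + best);
    }

    public static void main(String[] args) {
        Map<String, Integer> featureMap = Map.of("pixel-0", 0, "pixel-1", 1, "pixel-2", 2);
        Map<String, Integer> outputMap = Map.of("0", 0, "1", 1, "2", 2);
        float[] input = toDense(Map.of("pixel-1", 0.5, "pixel-9", 1.0), featureMap, 3);
        System.out.println(Arrays.toString(input)); // [0.0, 0.5, 0.0]
        System.out.println(toLabel(new double[]{0.1, 0.7, 0.2}, outputMap)); // 1
    }
}
```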

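// Imports for the OCI model wrapper (skip if already imported earlier)
import java.nio.file.Paths;
import org.tribuo.interop.oci.OCILabelConverter;
import org.tribuo.interop.oci.OCIModel;

// true indicates the deployed model emits probabilities for each label
var ociLabelConverter = new OCILabelConverter(true);
// java.nio does not expand "~", so build the config path from the user's home directory explicitly
var ociConfigPath = Paths.get(System.getProperty("user.home"), ".oci", "config");
var ociModel = OCIModel.createOCIModel(labelFactory, mnistFeatureMap, mnistOutputMap,
                                       ociConfigPath,     // OCI authentication config
                                       deployURL,         // Model endpoint URL
                                       ociLabelConverter); // Output converter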

As OCIModel is a Tribuo model we can evaluate it using our standard tools.

var ociStartTime = System.currentTimeMillis();
var ociEval = labelEvaluator.evaluate(ociModel,mnistTest);
var ociEndTime = System.currentTimeMillis();
System.out.println("Scoring OCI model took " + Util.formatDuration(ociStartTime,ociEndTime));
System.out.println(ociEval.toString());
System.out.println(ociEval.getConfusionMatrix().toString());

System.out.println("Predictions are equal - " + 
                    checkPredictions(ociEval.getPredictions(), mnistFMEval.getPredictions(), 1e-5));
Scoring OCI model took (00:01:06:960)
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         959          21          31       0.979       0.969       0.974
1                           1,135       1,120          15          22       0.987       0.981       0.984
2                           1,032         976          56          57       0.946       0.945       0.945
3                           1,010         952          58          39       0.943       0.961       0.952
4                             982         952          30          49       0.969       0.951       0.960
5                             892         857          35          63       0.961       0.932       0.946
6                             958         920          38          30       0.960       0.968       0.964
7                           1,028         969          59          36       0.943       0.964       0.953
8                             974         916          58          57       0.940       0.941       0.941
9                           1,009         951          58          44       0.943       0.956       0.949
Total                      10,000       9,572         428         428
Accuracy                                                                    0.957
Micro Average                                                               0.957       0.957       0.957
Macro Average                                                               0.957       0.957       0.957
Balanced Error Rate                                                         0.043
               0       1       2       3       4       5       6       7       8       9
0            959       0       0       0       1       2       7       4       4       3
1              0   1,120       4       1       3       0       3       0       4       0
2              6       5     976       7       7       2       5       8      14       2
3              0       2      15     952       0      19       1       3      14       4
4              3       3       7       1     952       0       4       1       1      10
5              3       1       0       6       1     857       5       5      13       1
6              8       2       7       2       7      11     920       1       0       0
7              2       5      13       5       4       4       0     969       4      22
8              2       1       9       9      11      15       4       5     916       2
9              7       3       2       8      15      10       1       9       3     951

Predictions are equal - true
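The `checkPredictions` helper isn't shown in this section. As an assumption about what a check like it needs to do, here is a stdlib sketch that treats two predictions as equal when they choose the same class and every score agrees within an absolute tolerance (1e-5 in the call above). A tolerance is needed because the Python ONNX Runtime in the deployment and Tribuo's Java implementation may round floating-point values differently, so bit-for-bit equality would be too strict. `PredictionCheck` is a hypothetical name.

```java
/** Hypothetical stand-in for a tolerance-based prediction comparison over raw score arrays. */
class PredictionCheck {
    /** Index of the largest score in the row. */
    static int argmax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** True if both score rows pick the same class and agree element-wise within the tolerance. */
    static boolean samePrediction(double[] a, double[] b, double tolerance) {
        if (a.length != b.length) {
            return false;
        }
        for (int i = 0; i < a.length; i++) {
            // Written as !(x <= tol) so that a NaN difference counts as a mismatch.
            if (!(Math.abs(a[i] - b[i]) <= tolerance)) {
                return false;
            }
        }
        return argmax(a) == argmax(b);
    }

    /** Applies samePrediction to every row of two batches of scores. */
    static boolean allSame(double[][] first, double[][] second, double tolerance) {
        if (first.length != second.length) {
            return false;
        }
        for (int i = 0; i < first.length; i++) {
            if (!samePrediction(first[i], second[i], tolerance)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[][] local = {{0.1, 0.9}, {0.8, 0.2}};
        double[][] remote = {{0.1000001, 0.8999999}, {0.8, 0.2}};
        System.out.println("Predictions are equal - " + allSame(local, remote, 1e-5)); // true
    }
}
```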

We can see that the model produces the same predictions as the Tribuo version, though scoring takes longer because each call to predict sends a REST request to the deployment endpoint and so incurs network latency.
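As a worked check on how the columns of the evaluation output relate, the sketch below recomputes the class 0 row and the summary lines from the raw counts in the table: recall is tp/(tp+fn), precision is tp/(tp+fp), F1 is their harmonic mean, which simplifies to 2tp/(2tp+fn+fp), accuracy is total true positives over total examples, and the balanced error rate is one minus the macro-averaged recall. These are the standard definitions; `MetricCheck` is our own illustration, not Tribuo's evaluation code.

```java
import java.util.Locale;

/** Recomputes some of the evaluation metrics above from the tp, fn, fp and n counts. */
class MetricCheck {
    static double recall(int tp, int fn) { return tp / (double) (tp + fn); }

    static double precision(int tp, int fp) { return tp / (double) (tp + fp); }

    // Harmonic mean of precision and recall, written directly in terms of the counts.
    static double f1(int tp, int fn, int fp) { return 2.0 * tp / (2.0 * tp + fn + fp); }

    /** Unweighted mean of the per-class recalls, where n[i] = tp[i] + fn[i]. */
    static double macroRecall(int[] tp, int[] n) {
        double sum = 0.0;
        for (int i = 0; i < tp.length; i++) {
            sum += tp[i] / (double) n[i];
        }
        return sum / tp.length;
    }

    public static void main(String[] args) {
        // Class 0 row: tp = 959, fn = 21, fp = 31
        System.out.println(String.format(Locale.ROOT, "class 0: recall=%.3f prec=%.3f f1=%.3f",
                recall(959, 21), precision(959, 31), f1(959, 21, 31)));
        // tp and n columns for classes 0 to 9
        int[] tp = {959, 1120, 976, 952, 952, 857, 920, 969, 916, 951};
        int[] n = {980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009};
        int tpTotal = 0;
        int nTotal = 0;
        for (int i = 0; i < tp.length; i++) {
            tpTotal += tp[i];
            nTotal += n[i];
        }
        double macro = macroRecall(tp, n);
        System.out.println(String.format(Locale.ROOT, "accuracy=%.3f macro recall=%.3f BER=%.3f",
                tpTotal / (double) nTotal, macro, 1.0 - macro));
        // Prints:
        // class 0: recall=0.979 prec=0.969 f1=0.974
        // accuracy=0.957 macro recall=0.957 BER=0.043
    }
}
```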

Conclusion

We've looked at exporting models from Tribuo in ONNX format, where they can be used from other languages and runtimes, and deployed in cloud environments like OCI Data Science. Over time we plan to expand Tribuo's ONNX export support to cover more models. Tribuo's ONNX support is a separate module from the rest of Tribuo and could be used to build ONNX models in other packages on the JVM. If you're interested in expanding ONNX support in Java, you can open a GitHub issue for Tribuo, or talk to the ONNX community in their Slack workspace.