Model export and deployment tutorial¶
Tribuo works best as a library providing training and deployment inside the JVM where the application is running; however, sometimes you need to deploy models elsewhere, either in another programming environment like Python or in a cloud service. To support these use cases many of Tribuo's models can be exported as ONNX models, a cross-platform model exchange format which is widely supported across industry, for edge devices, hardware accelerators, and cloud services. Tribuo also supports loading ONNX models and scoring them as native Tribuo models; for more information on that see the external models tutorial.
This tutorial shows how to export models in ONNX format, how to recover the provenance information from Tribuo-exported ONNX models, and how to deploy an ONNX model on OCI Data Science (other cloud providers support ONNX models too). We'll export a factorization machine, create an ensemble of that factorization machine along with some other models, and export the ensemble; then we'll discuss how to interact with the provenance of an exported model, before concluding by deploying the model to OCI.
Setup¶
This tutorial requires ONNX Runtime to score the exported models, so by default it will only run on x86_64 platforms. ONNX Runtime can be compiled for ARM64 platforms, but that binary is not in the Maven Central jar Tribuo depends on, so it will need to be compiled from scratch to run this tutorial on ARM.
We're going to use MNIST as the example dataset for this tutorial, so you'll need to download it if you haven't already.
First the training set:
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Then the test set:
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
As usual we'll load in some jars for classification problems, along with Tribuo's ONNX Runtime and OCI interfaces.
%jars ./tribuo-classification-experiments-4.3.0-jar-with-dependencies.jar
%jars ./tribuo-oci-4.3.0-jar-with-dependencies.jar
%jars ./tribuo-onnx-4.3.0-jar-with-dependencies.jar
%jars ./tribuo-json-4.3.0-jar-with-dependencies.jar
import java.nio.file.Files;
import java.nio.file.Paths;
import org.tribuo.*;
import org.tribuo.classification.*;
import org.tribuo.classification.ensemble.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.fm.FMClassificationTrainer;
import org.tribuo.classification.sgd.linear.*;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.ensemble.*;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.datasource.IDXDataSource;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.interop.onnx.*;
import org.tribuo.math.optimisers.*;
import org.tribuo.interop.oci.*;
import org.tribuo.util.onnx.*;
import org.tribuo.util.Util;
import com.oracle.bmc.ConfigFileReader;
import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider;
import com.oracle.bmc.datascience.DataScienceClient;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.util.Pair;
import ai.onnxruntime.*;
Then we'll load in MNIST.
var labelFactory = new LabelFactory();
var labelEvaluator = new LabelEvaluator();
var mnistTrainSource = new IDXDataSource<>(Paths.get("train-images-idx3-ubyte.gz"),Paths.get("train-labels-idx1-ubyte.gz"),labelFactory);
var mnistTestSource = new IDXDataSource<>(Paths.get("t10k-images-idx3-ubyte.gz"),Paths.get("t10k-labels-idx1-ubyte.gz"),labelFactory);
var mnistTrain = new MutableDataset<>(mnistTrainSource);
var mnistTest = new MutableDataset<>(mnistTestSource);
System.out.println(String.format("MNIST train size = %d, number of features = %d, number of classes = %d",mnistTrain.size(),mnistTrain.getFeatureMap().size(),mnistTrain.getOutputInfo().size()));
System.out.println(String.format("MNIST test size = %d, number of features = %d, number of classes = %d",mnistTest.size(),mnistTest.getFeatureMap().size(),mnistTest.getOutputInfo().size()));
Exporting a single classification model¶
We're going to train a multi-class factorization machine, a non-linear model which approximates the pairwise feature interactions using a small embedding vector per feature. It's similar to a logistic regression with an additional feature-feature interaction term, one per output label. In Tribuo, factorization machines are trained using stochastic gradient descent, with the standard SGD algorithms Tribuo uses for other models. We're going to use AdaGrad as it's usually a good baseline.
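Concretely, for each output label $c$ a factorization machine scores an input $\mathbf{x}$ using the standard FM form (a sketch of the usual formulation; Tribuo's exact parameterisation may differ in minor details):

$$ s_c(\mathbf{x}) = b_c + \mathbf{w}_c^{\top}\mathbf{x} + \sum_{i=1}^{d}\sum_{j=i+1}^{d} \langle \mathbf{v}_{c,i}, \mathbf{v}_{c,j} \rangle \, x_i x_j $$

where $b_c$ is a bias, $\mathbf{w}_c$ is a linear weight vector, and $\mathbf{v}_{c,i} \in \mathbb{R}^k$ is the factor vector for feature $i$, with $k$ the factor size (6 in the trainer below). The label scores are then converted into probabilities with a softmax.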
var fmLabelTrainer = new FMClassificationTrainer(new LogMulticlass(), // Loss function
        new AdaGrad(0.1,0.1),  // Gradient optimiser
        5,                     // Number of training epochs
        30000,                 // Logging interval
        Trainer.DEFAULT_SEED,  // RNG seed
        6,                     // Factor size
        0.1                    // Factor initialisation variance
);
After defining the model we train it as usual. Factorization machines take a little longer to train than logistic regression does, but not excessively so.
var fmStartTime = System.currentTimeMillis();
var fmMNIST = fmLabelTrainer.train(mnistTrain);
var fmEndTime = System.currentTimeMillis();
System.out.println("Training factorization machine took " + Util.formatDuration(fmStartTime,fmEndTime));
And then evaluate it using Tribuo's built in evaluation system.
fmStartTime = System.currentTimeMillis();
var mnistFMEval = labelEvaluator.evaluate(fmMNIST,mnistTest);
fmEndTime = System.currentTimeMillis();
System.out.println("Scoring factorization machine took " + Util.formatDuration(fmStartTime,fmEndTime));
System.out.println(mnistFMEval.toString());
System.out.println(mnistFMEval.getConfusionMatrix().toString());
We get about 95% accuracy on MNIST, which is pretty good for a fairly simple model. Now let's export it to ONNX, then we'll load it back in via Tribuo's ONNX Runtime interface and compare the performance. We'll use this model in the reproducibility tutorial so we'll save it to disk in the tutorials folder.
Tribuo Models which support ONNX export implement the ONNXExportable interface, which defines methods for constructing an ONNX protobuf and saving it to disk.
var fmMNISTPath = Paths.get(".","fm-mnist.onnx");
fmMNIST.saveONNXModel("org.tribuo.tutorials.onnxexport.fm", // namespace for the model
        0,           // model version number
        fmMNISTPath  // path to save the model
);
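ONNXExportable also exposes the protobuf directly via exportONNXModel, which is useful if you want to manipulate or serialize the model yourself rather than writing straight to disk. A minimal sketch (the fm-mnist-copy.onnx path is just illustrative):

// Sketch: build the ONNX protobuf in memory, then write it out manually
var modelProto = fmMNIST.exportONNXModel("org.tribuo.tutorials.onnxexport.fm", 0);
try (var os = Files.newOutputStream(Paths.get(".","fm-mnist-copy.onnx"))) {
    modelProto.writeTo(os);
}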
To load an ONNX model we need to define the mapping between Tribuo's feature names and the indices that the ONNX model understands. Fortunately for models exported from Tribuo we already have that information, as it is stored in the feature and output maps. We'll extract it into the general form that ONNXExternalModel expects.
Map<String, Integer> mnistFeatureMap = new HashMap<>();
for (VariableInfo f : fmMNIST.getFeatureIDMap()) {
    VariableIDInfo id = (VariableIDInfo) f;
    mnistFeatureMap.put(id.getName(),id.getID());
}
Map<Label, Integer> mnistOutputMap = new HashMap<>();
for (Pair<Integer,Label> l : fmMNIST.getOutputIDInfo()) {
    mnistOutputMap.put(l.getB(), l.getA());
}
Now we'll define a test function that compares two sets of predictions. ONNX Runtime computes in single precision while Tribuo uses double precision, so the prediction scores are never bitwise equal and we compare them up to a tolerance.
public boolean checkPredictions(List<Prediction<Label>> nativePredictions, List<Prediction<Label>> onnxPredictions, double delta) {
    for (int i = 0; i < nativePredictions.size(); i++) {
        Prediction<Label> tribuo = nativePredictions.get(i);
        Prediction<Label> external = onnxPredictions.get(i);
        // Check the predicted label
        if (!tribuo.getOutput().getLabel().equals(external.getOutput().getLabel())) {
            System.out.println("At index " + i + " predictions are not equal - "
                    + tribuo.getOutput().getLabel() + " and "
                    + external.getOutput().getLabel());
            return false;
        }
        // Check the maximum score
        if (Math.abs(tribuo.getOutput().getScore() - external.getOutput().getScore()) > delta) {
            System.out.println("At index " + i + " predictions are not equal - "
                    + tribuo.getOutput() + " and "
                    + external.getOutput());
            return false;
        }
        // Check the score distribution
        for (Map.Entry<String, Label> l : tribuo.getOutputScores().entrySet()) {
            Label other = external.getOutputScores().get(l.getKey());
            if (other == null) {
                System.out.println("At index " + i + " failed to find label " + l.getKey() + " in ORT prediction.");
                return false;
            } else {
                if (Math.abs(l.getValue().getScore() - other.getScore()) > delta) {
                    System.out.println("At index " + i + " predictions are not equal - "
                            + tribuo.getOutputScores() + " and "
                            + external.getOutputScores());
                    return false;
                }
            }
        }
    }
    return true;
}
Then we'll construct the ONNXExternalModel, loading our freshly created ONNX model using the feature and output mappings we built earlier. First we create a SessionOptions object which controls the model inference. By default it uses a single thread on one CPU, but by setting values on the options object before building the external model we can make it run on multiple threads, or use GPUs and other accelerator hardware supported by ONNX Runtime.
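For example, here's a minimal sketch of a non-default configuration (the cell below sticks with the defaults; the thread counts are illustrative, and addCUDA requires the GPU build of ONNX Runtime):

// Sketch: configure ONNX Runtime threading/hardware before building the external model
var tunedOpts = new OrtSession.SessionOptions();
tunedOpts.setIntraOpNumThreads(4); // threads used within a single operator
tunedOpts.setInterOpNumThreads(2); // threads used across independent operators
//tunedOpts.addCUDA(0);            // run on GPU device 0 (GPU build of ONNX Runtime only)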
var ortEnv = OrtEnvironment.getEnvironment();
var sessionOpts = new OrtSession.SessionOptions();
var denseTransformer = new DenseTransformer();
var labelTransformer = new LabelTransformer();
ONNXExternalModel<Label> onnxFM = ONNXExternalModel.createOnnxModel(labelFactory, mnistFeatureMap, mnistOutputMap,
        denseTransformer, labelTransformer, sessionOpts, fmMNISTPath, "input");
An ONNXExternalModel is a Tribuo model, so we can use the same evaluation infrastructure.
var onnxStartTime = System.currentTimeMillis();
var mnistONNXEval = labelEvaluator.evaluate(onnxFM,mnistTest);
var onnxEndTime = System.currentTimeMillis();
System.out.println("Scoring ONNX factorization machine took " + Util.formatDuration(onnxStartTime,onnxEndTime));
System.out.println(mnistONNXEval.toString());
System.out.println(mnistONNXEval.getConfusionMatrix().toString());
The two models evaluate the same, but they could still be producing slightly different probability values, so let's check with our more precise comparison function. checkPredictions will log any divergence it finds and returns false if the predictions differ. We're going to use a delta of 1e-5, and consider differences below that threshold to be irrelevant.
System.out.println("Predictions are equal - " +
checkPredictions(mnistFMEval.getPredictions(), mnistONNXEval.getPredictions(), 1e-5));
An important part of a Tribuo model is its provenance. We don't want to lose that information when exporting models to ONNX format, so we encode the provenance into the ONNX protobuf. It uses the marshalled provenance format from OLCUT, and the protos are available in OLCUT so they can be parsed in other systems. As a result, when loading a Tribuo-exported ONNX model the ONNXExternalModel class has two provenance objects: one for the ONNXExternalModel itself, and one for the original Model object.
Let's examine both of these provenances. First, the one for the ONNXExternalModel:
System.out.println("ONNXExternalModel provenance:\n" + ProvenanceUtil.formattedProvenanceString(onnxFM.getProvenance()));
This has the location the ONNX file was loaded from, a hash of the file, and timestamps for both the ONNX file and the model object wrapping it.
Now let's look at the original Model provenance:
System.out.println("ONNX file provenance:\n" + ProvenanceUtil.formattedProvenanceString(onnxFM.getTribuoProvenance().get()));
We can also check that the provenance extracted from the ONNX file is the same as the provenance in the original model object.
var equality = fmMNIST.getProvenance().equals(onnxFM.getTribuoProvenance().get()) ? "equal" : "not equal";
System.out.println("Provenances are " + equality);
Exporting an ensemble¶
Tribuo allows the creation of arbitrary ensembles, and these are often powerful models which are useful to deploy. So we're going to make a 3 element voting ensemble out of our factorization machine along with two other models, and export that to ONNX as well. The other models are a logistic regression and a smaller factorization machine, but we could use any classification model supported by Tribuo, including another ensemble. As this is a small ensemble of similar models, the goal is to demonstrate the functionality rather than to substantially improve performance on MNIST.
var lrTrainer = new LogisticRegressionTrainer();
var smallFMTrainer = new FMClassificationTrainer(new LogMulticlass(), // Loss function
        new AdaGrad(0.1,0.1),  // Gradient optimiser
        2,                     // Number of training epochs
        30000,                 // Logging interval
        42L,                   // RNG seed
        3,                     // Factor size
        0.1                    // Factor initialisation variance
);
var lrModel = lrTrainer.train(mnistTrain);
var smallFMModel = smallFMTrainer.train(mnistTrain);
Tribuo's WeightedEnsembleModel class allows the creation of arbitrary ensembles, with or without voting weights. We're going to create an unweighted ensemble of our three models using the standard VotingCombiner, which takes a majority vote between the three models, with ties broken by the first label.
var ensemble = WeightedEnsembleModel.createEnsembleFromExistingModels("ensemble", // Model name
        List.of(fmMNIST,lrModel,smallFMModel), // Ensemble members
        new VotingCombiner());                 // Combination operator
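WeightedEnsembleModel also accepts explicit voting weights. Here's a sketch assuming the overload of createEnsembleFromExistingModels which takes a float[] of per-member weights (the weight values are illustrative):

// Sketch: give the large factorization machine twice the vote of the other members
var weightedEnsemble = WeightedEnsembleModel.createEnsembleFromExistingModels("weighted-ensemble", // Model name
        List.of(fmMNIST,lrModel,smallFMModel), // Ensemble members
        new VotingCombiner(),                  // Combination operator
        new float[]{2.0f,1.0f,1.0f});          // Per-member voting weights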
var ensembleStartTime = System.currentTimeMillis();
var ensembleEval = labelEvaluator.evaluate(ensemble,mnistTest);
var ensembleEndTime = System.currentTimeMillis();
System.out.println("Scoring ensemble took " + Util.formatDuration(ensembleStartTime,ensembleEndTime));
System.out.println(ensembleEval.toString());
System.out.println(ensembleEval.getConfusionMatrix().toString());
As before, we use the saveONNXModel method on the ONNXExportable interface to write out the model. Note that if one of the ensemble members isn't ONNXExportable then this call will throw a runtime exception.
var ensemblePath = Paths.get(".","ensemble-mnist.onnx");
ensemble.saveONNXModel("org.tribuo.tutorials.onnxexport.ensemble", // namespace for the model
        0,            // model version number
        ensemblePath  // path to save the model
);
We can load this model into ONNXExternalModel as well:
var onnxEnsemble = ONNXExternalModel.createOnnxModel(labelFactory, mnistFeatureMap, mnistOutputMap,
        denseTransformer, labelTransformer, sessionOpts, ensemblePath, "input");
onnxStartTime = System.currentTimeMillis();
var mnistONNXEnsembleEval = labelEvaluator.evaluate(onnxEnsemble,mnistTest);
onnxEndTime = System.currentTimeMillis();
System.out.println("Scoring ONNX ensemble took " + Util.formatDuration(onnxStartTime,onnxEndTime));
System.out.println("Predictions are equal - " +
checkPredictions(ensembleEval.getPredictions(), mnistONNXEnsembleEval.getPredictions(), 1e-5));
Deploying the model¶
This portion of the tutorial describes how to deploy the ONNX model on OCI Data Science using its model deployment service. ONNX models can also be deployed via Oracle Machine Learning Services, in other machine learning cloud services and environments (including other cloud providers), or behind a functions-as-a-service offering using something like ONNX Runtime.
Tribuo's OCI Data Science support comes in two parts: a set of static methods for deploying models in the cloud, and the OCIModel class which wraps a model endpoint and allows it to be used as a normal Tribuo model. Underneath the covers we're going to use an OCI DS conda environment which contains ONNX Runtime in Python, and use that to make predictions from our model trained in Java.
To run this part of the tutorial you'll need to have configured your access to OCI Data Science (if you've not done this before then you can see a tutorial on how to do that here), to have set up authentication allowing CLI access to OCI, and to have the compartment & project IDs for the OCI Data Science project you want to deploy into.
// Set these variables appropriately for your OCI account
var compartmentID = "your-oci-compartment-id";
var projectID = "your-oci-ds-project-id";
Now we'll instantiate the DS client, and build the config object which captures all the information about the model we're uploading. The models are run inside a conda environment, and you need to select one which contains ONNX Runtime 1.6.0 or newer (as Tribuo emits ONNX models using Opset 13, which is supported in ONNX Runtime 1.6+). This can either be a custom one you've created, or one provided by OCI Data Science.
// Instantiate the client
var provider = new ConfigFileAuthenticationDetailsProvider(ConfigFileReader.parseDefault());
var dsClient = new DataScienceClient(provider);
// Instantiate an ObjectMapper for parsing the REST calls
var objMapper = OCIUtil.createObjectMapper();
// Select the conda environment
var condaName = "dataexpl_p37_cpu_v3"; // Also referred to as the "slug" in the OCI DS docs
var condaPath = "oci://service-conda-packs@id19sfcrra6z/service_pack/cpu/Data Exploration and Manipulation for CPU Python 3.7/3.0/dataexpl_p37_cpu_v3";
// Instantiate the model configuration
var dsConfig = new OCIUtil.OCIDSConfig(compartmentID,projectID);
var modelConfig = new OCIUtil.OCIModelArtifactConfig(dsConfig, // Data Science config
        "tribuo-tutorial-model",   // Model name
        "A factorization machine", // Model description
        "org.tribuo.tutorial.test",// ONNX model domain
        0,          // ONNX model version
        condaName,  // Conda environment name
        condaPath); // Conda environment path on object storage
We can now upload the model into OCI Data Science. The createModel method has an overload that accepts an ONNX file on disk, or you can pass in any model which implements ONNXExportable. Tribuo takes care of setting the model metadata using the information it can extract from the Model object, and it automatically generates the Python script and YAML file which control the model's environment in the deployment. Note that models are distinct from model deployments, so a single model artifact can be deployed multiple times with different endpoints, VM sizes and scaling parameters.
var modelID = OCIUtil.createModel(fmMNIST,dsClient,objMapper,modelConfig);
The modelID is the reference for the model artifact stored in Oracle Cloud; we'll need it to create a deployment wrapping the model.
To specify the model deployment configuration there's an OCIModelDeploymentConfig wrapper class, which contains the model ID, the model deployment name, the VM shape, the maximum number of VM instances to create, and the bandwidth available to that model. At the time of writing OCI DS supports the VM.Standard2 shapes.
var deployConfig = new OCIUtil.OCIModelDeploymentConfig(dsConfig,modelID,"tribuo-tutorial-deployment","VM.Standard2.1",10,1);
var deployURL = OCIUtil.deploy(deployConfig,dsClient,objMapper);
System.out.println(deployURL);
Model deployments take a few minutes, so you'll need to wait a while if you've been following along with the tutorial. The deployment progress can be checked on the OCI console for the data science project you are using.
Once the deployment has finished, we can wrap it in an OCIModel and check it behaves the same as the factorization machine we deployed. An OCIModel is a subclass of ExternalModel in the same way that externally trained ONNX models are, so we need to supply the mapping between Tribuo's feature domain & the feature indices expected by the model, the output domain mapping, and an OCIOutputConverter instance which converts the prediction matrix into Tribuo's Prediction objects. As we've deployed a factorization machine for MNIST we'll use OCILabelConverter, and the mappings are the same as the ones we used for the ONNX model earlier.
var ociLabelConverter = new OCILabelConverter(true);
var ociModel = OCIModel.createOCIModel(labelFactory, mnistFeatureMap, mnistOutputMap,
        Paths.get("~/.oci/config"), // OCI authentication config
        deployURL,                  // Model endpoint URL
        ociLabelConverter);         // Output converter
As OCIModel is a Tribuo model, we can evaluate it using our standard tools.
Note when running this notebook from scratch the OCI Model Deployment can take up to 15 minutes to fully instantiate, and the next cell will not execute correctly until that deployment has finished. You can monitor the status of the deployment in the OCI console.
var ociStartTime = System.currentTimeMillis();
var ociEval = labelEvaluator.evaluate(ociModel,mnistTest);
var ociEndTime = System.currentTimeMillis();
System.out.println("Scoring OCI model took " + Util.formatDuration(ociStartTime,ociEndTime));
System.out.println(ociEval.toString());
System.out.println(ociEval.getConfusionMatrix().toString());
System.out.println("Predictions are equal - " +
checkPredictions(ociEval.getPredictions(), mnistFMEval.getPredictions(), 1e-5));
We can see that the model performs identically to the Tribuo version, though it takes a little longer as each call to predict incurs some network latency.
Conclusion¶
We've looked at exporting models out of Tribuo in ONNX format, where they can be used from other languages and runtimes, and deployed in cloud environments like OCI Data Science. Over time we plan to expand Tribuo's support for ONNX export to cover more models. Tribuo's ONNX support lives in a separate module from the rest of Tribuo and could be used to build ONNX models in other packages on the JVM. If you're interested in expanding the support for ONNX in Java, you can open a GitHub issue for Tribuo, or talk to the ONNX community in their Slack workspace.