Working with external models

Tribuo can load models trained in third party systems and deploy them alongside native Tribuo models. In Tribuo 4.1 we support models trained externally in XGBoost, TensorFlow frozen graphs and saved models, and models stored in ONNX (Open Neural Network Exchange) format. The latter is particularly interesting for Tribuo as many libraries can export models in ONNX format, such as scikit-learn, pytorch and TensorFlow, among others. For a more complete list of the supported ONNX models you can look at the ONNX website. Tribuo's ONNX support is supplied by ONNX Runtime, using the Java interface our group in Oracle Labs contributed to that project.

In this tutorial we'll look at loading in models trained in XGBoost, scikit-learn and pytorch, all for MNIST, and we'll deploy them next to a logistic regression model trained in Tribuo. We discuss using external TensorFlow models in the TensorFlow tutorial, as TensorFlow brings its own complexities. Note these models all depend on native libraries, which are available for x86_64 platforms on Windows, Linux and macOS. Both ONNX Runtime and XGBoost support macOS arm64 (i.e., Apple Silicon Macs), but you'll need to compile those from source and add them to Tribuo's class path to make this tutorial run on that platform.

Feature names and feature indices

Most of the difficulty in loading in third party models comes in dealing with how the features are presented to the model and how the outputs are read back. If the host system thinks feature num_wheels has index 5, but the third party model expects feature num_wheels to have index 10, then the indices need to be transformed before the model can be used for inference. Similarly, if the third party model assigns class car to index 0 and class bike to index 1, then the host system must know that mapping to return the correct outputs, otherwise it could confuse the two classes.

Most ML libraries work purely with feature indices, so this task becomes a matter of ensuring the indices line up. However, the model won't raise an exception if the indices don't line up. All it can check is whether the number of features in an example matches the number it's expecting; it doesn't know whether the indices themselves line up with the right feature values. Tribuo avoids problems with feature indices (which can be particularly tricky in sparse problems like natural language processing) by naming all the features. The indices assigned to those names are an internal implementation detail which Tribuo users don't need to know about. Unfortunately, when loading a third party model we require that the user tells us the mapping from Tribuo's feature names to the external model's feature indices, and also from output indices to output names or labels. Presenting this information to Tribuo is the tricky part of working with external models, as it requires understanding how Tribuo's feature names are generated.
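
To make the index problem concrete, here is a minimal sketch in plain Java of what "transforming the indices" amounts to. It is not Tribuo's implementation: the class FeatureRemapSketch, its method and the num_doors feature are hypothetical names used only for illustration, while num_wheels and index 10 come from the example above.

```java
import java.util.Map;

// Hypothetical helper, not part of Tribuo: places named feature values
// at the positions an external model expects.
final class FeatureRemapSketch {

    static float[] toExternalVector(Map<String, Float> namedValues,
                                    Map<String, Integer> nameToExternalIndex,
                                    int numFeatures) {
        float[] dense = new float[numFeatures]; // features not present stay at 0
        for (Map.Entry<String, Float> e : namedValues.entrySet()) {
            Integer idx = nameToExternalIndex.get(e.getKey());
            if (idx == null) {
                continue; // the external model has no slot for this feature
            }
            if (idx < 0 || idx >= numFeatures) {
                throw new IllegalArgumentException("Index out of range for " + e.getKey());
            }
            dense[idx] = e.getValue();
        }
        return dense;
    }

    public static void main(String[] args) {
        // The external model expects num_wheels at index 10.
        Map<String, Integer> mapping = Map.of("num_wheels", 10, "num_doors", 5);
        float[] vector = toExternalVector(
                Map.of("num_wheels", 4.0f, "num_doors", 2.0f), mapping, 11);
        System.out.println("index 10 = " + vector[10] + ", index 5 = " + vector[5]);
    }
}
```

Note that nothing in this sketch can detect a mapping which is well formed but wrong: swapping the two indices would still produce a vector of the right length.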


As usual we add some jar files to the classpath and import some classes from Tribuo and the JDK. We'll pull in the classification experiments jar as that contains XGBoost, and the ONNX jar containing Tribuo's interface to ONNX Runtime.

In [1]:
%jars tribuo-classification-experiments-4.1.0-jar-with-dependencies.jar
%jars tribuo-onnx-4.1.0-jar-with-dependencies.jar
In [2]:
import java.nio.file.Files;
import java.nio.file.Paths;

import org.tribuo.*;
import org.tribuo.datasource.IDXDataSource;
import org.tribuo.classification.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.linear.LinearSGDTrainer;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.classification.xgboost.*;
import org.tribuo.common.xgboost.XGBoostExternalModel;
import org.tribuo.interop.onnx.*;
import org.tribuo.math.optimisers.AdaGrad;

import ai.onnxruntime.*;

We'll also need some data to work with, so we'll load in the MNIST train and test sets. We'll use Tribuo's built-in IDXDataSource to read them, the same as in the configuration tutorial. If you've already downloaded MNIST then you can skip the download step.

First download the training data:


Then the test data:


Tribuo's IDX loader natively reads gzipped files so you don't need to unzip them. Tribuo doesn't natively understand the 2d pixel arrangement, so the feature names from the IDXDataSource are just the integers 000 through 783, with leading zero padding to make it up to 3 digits. This will be important later when we look at scoring CNN based models.

In [3]:
var labelFactory = new LabelFactory();
var mnistTrainSource = new IDXDataSource<>(Paths.get("train-images-idx3-ubyte.gz"),Paths.get("train-labels-idx1-ubyte.gz"),labelFactory);
var mnistTestSource = new IDXDataSource<>(Paths.get("t10k-images-idx3-ubyte.gz"),Paths.get("t10k-labels-idx1-ubyte.gz"),labelFactory);
var mnistTrain = new MutableDataset<>(mnistTrainSource);
var mnistTest = new MutableDataset<>(mnistTestSource);
System.out.println(String.format("Training data size = %d, number of features = %d, number of classes = %d",mnistTrain.size(),mnistTrain.getFeatureMap().size(),mnistTrain.getOutputInfo().size()));
System.out.println(String.format("Testing data size = %d, number of features = %d, number of classes = %d",mnistTest.size(),mnistTest.getFeatureMap().size(),mnistTest.getOutputInfo().size()));
Training data size = 60000, number of features = 717, number of classes = 10
Testing data size = 10000, number of features = 668, number of classes = 10
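
The naming scheme described above can be sketched in a few lines of plain Java. This is an illustration rather than Tribuo's code: the class and method names are hypothetical, and it assumes the pixels are flattened in row-major order (so pixel [0,0] is "000" and the last pixel of the bottom row is "783").

```java
// Hypothetical illustration of the zero-padded feature names; not Tribuo code.
final class IdxFeatureNameSketch {
    static final int WIDTH = 28; // MNIST images are 28 x 28

    // Name of the feature for a pixel, assuming row-major flattening.
    static String featureName(int row, int col) {
        return String.format("%03d", row * WIDTH + col);
    }

    // Recovers {row, col} from a feature name such as "029".
    static int[] pixelOf(String name) {
        int flat = Integer.parseInt(name); // leading zeros are accepted
        return new int[] {flat / WIDTH, flat % WIDTH};
    }

    public static void main(String[] args) {
        System.out.println(featureName(0, 0));   // first pixel
        System.out.println(featureName(27, 27)); // last pixel
        int[] rc = pixelOf("029");
        System.out.println(rc[0] + "," + rc[1]);
    }
}
```

The padding means that sorting the names as strings gives the same order as sorting the pixel indices numerically.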

Building a Tribuo baseline

First we'll train a logistic regression in Tribuo, using AdaGrad for 3 epochs.

In [4]:
var lrTrainer = new LinearSGDTrainer(new LogMulticlass(),new AdaGrad(0.1),3,30000,Trainer.DEFAULT_SEED);
var lrModel = lrTrainer.train(mnistTrain);
var lrEvaluation = labelFactory.getEvaluator().evaluate(lrModel,mnistTest);
System.out.println(lrEvaluation.toString());
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         944          36          41       0.963       0.958       0.961
1                           1,135       1,121          14          75       0.988       0.937       0.962
2                           1,032         906         126          88       0.878       0.911       0.894
3                           1,010         781         229          44       0.773       0.947       0.851
4                             982         888          94          77       0.904       0.920       0.912
5                             892         770         122         149       0.863       0.838       0.850
6                             958         923          35         111       0.963       0.893       0.927
7                           1,028         954          74         112       0.928       0.895       0.911
8                             974         853         121         229       0.876       0.788       0.830
9                           1,009         852         157          82       0.844       0.912       0.877
Total                      10,000       8,992       1,008       1,008
Accuracy                                                                    0.899
Micro Average                                                               0.899       0.899       0.899
Macro Average                                                               0.898       0.900       0.898
Balanced Error Rate                                                         0.102

We get a baseline performance of just under 90% accuracy (0.899), which we could bump by tuning the various hyperparameters, but it'll do as a demonstration of training a native Tribuo model.
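
As a reading aid for the evaluation tables in this tutorial, the sketch below recomputes the per-class columns from the tp, fn and fp counts using the standard definitions of recall, precision and F1. It is plain Java with hypothetical class and method names, not Tribuo's evaluator, and it uses the class 0 row from the table above as input.

```java
import java.util.Locale;

// Hypothetical helper showing how the table columns relate; not Tribuo code.
final class MetricsSketch {

    static double recall(int tp, int fn) {
        return tp / (double) (tp + fn);
    }

    static double precision(int tp, int fp) {
        return tp / (double) (tp + fp);
    }

    // F1 is the harmonic mean of precision and recall.
    static double f1(int tp, int fn, int fp) {
        return 2.0 * tp / (2.0 * tp + fn + fp);
    }

    public static void main(String[] args) {
        // Class 0 of the logistic regression: tp = 944, fn = 36, fp = 41.
        System.out.println(String.format(Locale.ROOT, "%.3f %.3f %.3f",
                recall(944, 36), precision(944, 41), f1(944, 36, 41)));
        // prints "0.963 0.958 0.961", matching the class 0 row
    }
}
```

The accuracy row is the total tp divided by the number of examples (8,992 / 10,000 here).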

Loading in an XGBoost model

To load in a third party XGBoost model we use the XGBoostExternalModel class. Unlike Tribuo's training interface for XGBoost models, which is specific to the prediction task, this class supports both classification and regression models, and the output type is determined by the way it's constructed.

We can construct an XGBoostExternalModel from a saved xgb file on disk. If you've got an XGBoost4J Booster object in memory, then you'll need to write it out to disk first. We'll look at relaxing this restriction if this is a popular use case (but it's likely that if you've already got it in memory, then you could have trained it with Tribuo too). Also, due to our use of XGBoost's internal deserialization mechanism, the file can't be gzipped or otherwise compressed (again, it's possible to relax this restriction if there is user interest).

As mentioned above, we're going to use a classification model, so first we need to instantiate an XGBoostClassificationConverter, which converts between Tribuo's label representation and XGBoost's label representation. This is usually hidden when working with XGBoost models trained in Tribuo, but as we can load external models of either classification or regression types we need to tell the system what kind of output it's working with.

In [5]:
var xgbLabelConv = new XGBoostClassificationConverter();
var xgbModelPath = Paths.get("external-models","xgb_mnist.xgb");

Now we come to the complicated bit, building the mapping from Tribuo's feature names to XGBoost's feature indices, and building the mapping from Tribuo's Labels to XGBoost's output indices. When training a model inside Tribuo this is all performed automatically, but when loading an XGBoost model trained in Python, R, or even XGBoost4J without Tribuo we need to provide that information.

In some cases this might be a completely trivial mapping, where Tribuo used String representations of integer feature numbers, and those feature numbers map directly to the ones used to train the external model. This can happen if you use libsvm formatted datasets which number their features, or if you've written your own DataSource which uses integers as the feature names.

As we've loaded MNIST from the original IDX format we've got numerically increasing ids, so to make it more interesting the XGBoost model was trained with inverted feature ids (i.e., pixel [0,0] was given id "783" in the training run). We didn't change the output mapping, so the label ids are simply the integer (i.e., Label "0" has id 0).

In [6]:
Map<String, Integer> xgbFeatMapping = new HashMap<>();
for (int i = 0; i < 784; i++) {
    // This MNIST model has the feature indices inverted to test a non-trivial mapping.
    int id = (783 - i);
    xgbFeatMapping.put(String.format("%03d", i), id);
}
Map<Label, Integer> xgbOutMapping = new HashMap<>();
for (Label l : mnistTrain.getOutputInfo().getDomain()) {
    // The label ids are simply the integer value of the label.
    xgbOutMapping.put(l, Integer.parseInt(l.getLabel()));
}
Model<Label> xgbModel = XGBoostExternalModel.createXGBoostModel(labelFactory, xgbFeatMapping, xgbOutMapping, xgbLabelConv, xgbModelPath);
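
Because a bad mapping does not raise an error at inference time, it can be worth sanity checking a feature mapping before handing it to the model. The sketch below is one way to do that in plain Java: it checks that the indices form a permutation of 0 to n-1. The class and method names are hypothetical, and this check is not something Tribuo is claimed to perform in this form.

```java
import java.util.BitSet;
import java.util.HashMap;
import java.util.Map;

// Hypothetical sanity check for a name-to-index mapping; not Tribuo code.
final class MappingCheckSketch {

    // True if every index in [0, n) is used exactly once, where n is the mapping size.
    static boolean isPermutation(Map<String, Integer> mapping) {
        int n = mapping.size();
        BitSet seen = new BitSet(n);
        for (int idx : mapping.values()) {
            if (idx < 0 || idx >= n || seen.get(idx)) {
                return false; // out of range or used twice
            }
            seen.set(idx);
        }
        return true;
    }

    // Rebuilds the inverted MNIST mapping used in this tutorial.
    static Map<String, Integer> inverted(int numFeatures) {
        Map<String, Integer> mapping = new HashMap<>();
        for (int i = 0; i < numFeatures; i++) {
            mapping.put(String.format("%03d", i), numFeatures - 1 - i);
        }
        return mapping;
    }

    public static void main(String[] args) {
        System.out.println(isPermutation(inverted(784)));
    }
}
```

A check like this catches duplicated or out-of-range indices, but not a valid permutation which happens to be the wrong one; comparing the evaluation against the accuracy measured in the original training environment is the more reliable test.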

Now we've got the XGBoost model loaded into Tribuo, we can evaluate it the same way we evaluated the native Tribuo logistic regression.

In [7]:
var xgbEvaluation = labelFactory.getEvaluator().evaluate(xgbModel,mnistTest);
System.out.println(xgbEvaluation.toString());
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         963          17          24       0.983       0.976       0.979
1                           1,135       1,120          15          14       0.987       0.988       0.987
2                           1,032         989          43          32       0.958       0.969       0.963
3                           1,010         976          34          52       0.966       0.949       0.958
4                             982         945          37          32       0.962       0.967       0.965
5                             892         853          39          21       0.956       0.976       0.966
6                             958         928          30          25       0.969       0.974       0.971
7                           1,028         989          39          31       0.962       0.970       0.966
8                             974         929          45          50       0.954       0.949       0.951
9                           1,009         968          41          59       0.959       0.943       0.951
Total                      10,000       9,660         340         340
Accuracy                                                                    0.966
Micro Average                                                               0.966       0.966       0.966
Macro Average                                                               0.966       0.966       0.966
Balanced Error Rate                                                         0.034

We get 96.6% accuracy, as gradient boosted trees are a much more powerful model than logistic regression. We could get similar accuracy using XGBoost trained natively in Tribuo, but that wouldn't explain much about loading in external models.

Loading in an ONNX model

ONNX models can encapsulate both neural network and more traditional ML models, and as such they have more complex inputs and outputs. An ONNX model can accept tensor inputs of arbitrary dimension, unlike XGBoost which only accepts feature vectors. Different ONNX models also accept different tensors even for the same task. For example, when training an MNIST model in scikit-learn the model expects a feature vector [784], whereas if you train a CNN using pytorch on the same task it will expect a batch of inputs of size [batch_size,1,28,28], in the standard [batch,channels,height,width] format.

This means we need different input preprocessing logic to ensure that the tensors are formatted appropriately. In Tribuo this is controlled by supplying the appropriate implementation of ExampleTransformer. We supply two implementations, one for generating feature vectors and one for generating 4d tensors (i.e., image batches).

Similar to the XGBoost model above, we also need a converter from the ONNX output format into a Tribuo Prediction, which is encapsulated in the OutputTransformer interface. We supply two implementations, one for classification and one for regression. The classification one is further specialised as it accepts both scikit-learn style outputs (which produce a map from id to probability) and pytorch style outputs (which return a float matrix). If your ONNX model produces other outputs then you'll need to write your own converter.
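
To illustrate the two classification output styles just described, the following plain Java sketch picks the predicted class id from each. It is a simplified stand-in and not Tribuo's LabelTransformer: the class and method names are hypothetical, and the use of Long keys and Float values for the map style is an assumption about how such an output might be represented.

```java
import java.util.Map;

// Hypothetical illustration of the two output styles; not Tribuo's converter.
final class OutputStyleSketch {

    // scikit-learn style: a map from class id to probability for one example.
    static long argmaxOfMap(Map<Long, Float> probabilities) {
        long best = -1L;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (Map.Entry<Long, Float> e : probabilities.entrySet()) {
            if (e.getValue() > bestScore) {
                bestScore = e.getValue();
                best = e.getKey();
            }
        }
        return best;
    }

    // pytorch style: one row of a float matrix, where the position is the class id.
    static int argmaxOfRow(float[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(argmaxOfMap(Map.of(0L, 0.1f, 1L, 0.7f, 2L, 0.2f)));
        System.out.println(argmaxOfRow(new float[] {-1.2f, 3.4f, 0.5f}));
    }
}
```

In both cases the resulting id still has to be translated into a label through the output mapping supplied when the model is loaded.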

ONNX models also require an OrtSession.SessionOptions object, which controls how the ONNX model is scored. This is transient and needs to be set each time the ONNXExternalModel is loaded, as unfortunately the OrtSession.SessionOptions object is not introspectable and can't be serialized. If it's not set then it defaults to single threaded CPU computation.

Deploying a scikit-learn model

First we'll look at running inference on a model trained in scikit-learn, using a DenseTransformer to process the input. We used the sklearn-onnx converter to change the logistic regression we trained in scikit-learn into ONNX format. This approach should apply to any model trained in scikit-learn for classification tasks.

In [8]:
var denseTransformer = new DenseTransformer();
var labelTransformer = new LabelTransformer();
var onnxSklPath = Paths.get("external-models","skl_lr_mnist.onnx");
var ortEnv = OrtEnvironment.getEnvironment();
var sessionOpts = new OrtSession.SessionOptions();

Again this model was trained in scikit-learn using an inverted feature mapping to make it a little more interesting. We could reuse the mapping from XGBoost, but we'll paste it in here again with appropriate variable names for clarity. One further thing to note is that ONNX models have named inputs and outputs, so we need to tell Tribuo what input should be supplied. Tribuo matches the outputs based on the type and number of the values returned, so supplying the appropriate OutputTransformer will be sufficient.

In [9]:
Map<String, Integer> sklFeatMapping = new HashMap<>();
for (int i = 0; i < 784; i++) {
    // This MNIST model has the feature indices inverted to test a non-trivial mapping.
    int id = (783 - i);
    sklFeatMapping.put(String.format("%03d", i), id);
}
Map<Label, Integer> sklOutMapping = new HashMap<>();
for (Label l : mnistTrain.getOutputInfo().getDomain()) {
    sklOutMapping.put(l, Integer.parseInt(l.getLabel()));
}
// "float_input" is the name of the ONNX model's input.
Model<Label> sklModel = ONNXExternalModel.createOnnxModel(labelFactory, sklFeatMapping, sklOutMapping, 
                    denseTransformer, labelTransformer, sessionOpts, onnxSklPath, "float_input");

Now we've got the scikit-learn model loaded into Tribuo, we can evaluate it the same way we evaluated the native Tribuo logistic regression.

In [10]:
var sklEvaluation = labelFactory.getEvaluator().evaluate(sklModel,mnistTest);
System.out.println(sklEvaluation.toString());
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         963          17          46       0.983       0.954       0.968
1                           1,135       1,112          23          37       0.980       0.968       0.974
2                           1,032         926         106          70       0.897       0.930       0.913
3                           1,010         916          94          98       0.907       0.903       0.905
4                             982         910          72          64       0.927       0.934       0.930
5                             892         776         116          83       0.870       0.903       0.886
6                             958         910          48          55       0.950       0.943       0.946
7                           1,028         951          77          70       0.925       0.931       0.928
8                             974         869         105         133       0.892       0.867       0.880
9                           1,009         922          87          89       0.914       0.912       0.913
Total                      10,000       9,255         745         745
Accuracy                                                                    0.926
Micro Average                                                               0.926       0.926       0.926
Macro Average                                                               0.924       0.925       0.924
Balanced Error Rate                                                         0.076

We get 92.6% accuracy for the logistic regression in scikit-learn. We used the default hyperparameters in scikit-learn to build the model, and those hyperparameters perform a little better on MNIST than the default hyperparameters from Tribuo, but with sufficient tuning the models are equivalent.

Deploying a pytorch model

Now we'll run inference on a simple LeNet-style convolutional neural network (CNN) trained in pytorch, but this should apply to any ONNX model which accepts images as a 4d tensor and produces a classification output.

It's much the same as the other two examples; the difficulty comes in ensuring that the feature and output indices are lined up correctly. As this pytorch model expects an image tensor input we'll use the ImageTransformer to convert the feature indices, telling it to expect an image with a single channel which is 28 pixels wide and 28 pixels high.

In [11]:
var imageTransformer = new ImageTransformer(1,28,28);
var onnxPyTorchPath = Paths.get("external-models","pytorch_cnn_mnist.onnx");

This time we're going to use the identity mapping, because this CNN was trained on the standard MNIST images without any transpositions. Remembering that Tribuo's IDXDataSource flattens the feature ids, we just map the integers 0 through 783 to themselves and use the ImageTransformer to deal with the reshape into a tensor.

In [12]:
Map<String, Integer> ptFeatMapping = new HashMap<>();
for (int i = 0; i < 784; i++) {
    // Identity mapping: feature "000" has index 0, "783" has index 783.
    ptFeatMapping.put(String.format("%03d", i), i);
}
Map<Label, Integer> ptOutMapping = new HashMap<>();
for (Label l : mnistTrain.getOutputInfo().getDomain()) {
    ptOutMapping.put(l, Integer.parseInt(l.getLabel()));
}
// "input_image" is the name of the ONNX model's input.
Model<Label> ptModel = ONNXExternalModel.createOnnxModel(labelFactory, ptFeatMapping, ptOutMapping, imageTransformer,
                    labelTransformer, sessionOpts, onnxPyTorchPath, "input_image");
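
The reshape from flat feature ids into an image tensor can be pictured with the plain Java sketch below. It is not the ImageTransformer implementation: the class and method names are hypothetical, and it assumes the flat vector is laid out in [channels,height,width] order with the width index changing fastest, producing a batch of size one.

```java
// Hypothetical illustration of the flat-to-image reshape; not Tribuo code.
final class ImageReshapeSketch {

    // Reshapes a flat feature vector into a [1, channels, height, width] batch.
    static float[][][][] toImageBatch(float[] flat, int channels, int height, int width) {
        if (flat.length != channels * height * width) {
            throw new IllegalArgumentException(
                    "Expected " + (channels * height * width) + " values, found " + flat.length);
        }
        float[][][][] batch = new float[1][channels][height][width];
        for (int c = 0; c < channels; c++) {
            for (int h = 0; h < height; h++) {
                for (int w = 0; w < width; w++) {
                    batch[0][c][h][w] = flat[(c * height + h) * width + w];
                }
            }
        }
        return batch;
    }

    public static void main(String[] args) {
        float[] flat = new float[784];
        for (int i = 0; i < flat.length; i++) {
            flat[i] = i; // store the flat index so the layout is easy to inspect
        }
        float[][][][] batch = toImageBatch(flat, 1, 28, 28);
        // The first pixel of the second row came from flat index 28.
        System.out.println(batch[0][0][1][0]);
    }
}
```

With the identity feature mapping used above, the feature named "028" ends up at row 1, column 0 of the single channel.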

Now we've got the pytorch model loaded into Tribuo, we can evaluate it the same way we evaluated the rest of the models.

In [13]:
var ptEvaluation = labelFactory.getEvaluator().evaluate(ptModel,mnistTest);
System.out.println(ptEvaluation.toString());
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         978           2          14       0.998       0.986       0.992
1                           1,135       1,134           1          14       0.999       0.988       0.993
2                           1,032       1,021          11           6       0.989       0.994       0.992
3                           1,010       1,001           9          10       0.991       0.990       0.991
4                             982         978           4           9       0.996       0.991       0.993
5                             892         884           8           9       0.991       0.990       0.990
6                             958         944          14           3       0.985       0.997       0.991
7                           1,028       1,014          14          16       0.986       0.984       0.985
8                             974         962          12           5       0.988       0.995       0.991
9                           1,009         988          21          10       0.979       0.990       0.985
Total                      10,000       9,904          96          96
Accuracy                                                                    0.990
Micro Average                                                               0.990       0.990       0.990
Macro Average                                                               0.990       0.990       0.990
Balanced Error Rate                                                         0.010

Unsurprisingly, the LeNet-style CNN performs the best on the MNIST test set, giving 99% accuracy and beating both Tribuo's and scikit-learn's logistic regressions along with the XGBoost model. If you want to train CNNs and other deep learning models like MLPs in Tribuo, check out our TensorFlow support.


We saw how to load in externally trained models in multiple formats, and how to deploy those models alongside Tribuo's native models. We also looked at how ONNX models can accept different tensor shapes as inputs, and used Tribuo's mechanisms for converting an Example into either a vector or a tensor, depending on whether the external model expected a vector or an image as an input.

Given how useful the ONNX model import code is, allowing Tribuo to load in many different kinds of models trained in many different libraries, it's natural to ask what support Tribuo has for exporting ONNX models. At the moment we don't support exporting Tribuo's native models to ONNX format, but we're investigating how to do this purely from Java, and we hope to be able to do this in a future release.