Documentation

Introduction

Tribuo is a Java library for building and deploying Machine Learning models. The core development team is Oracle Labs' Machine Learning Research Group, and the library is available on GitHub under the Apache 2.0 license.

Tribuo has a modern Java-centric API design:

  • The API is strongly typed, with parameterised classes for models, predictions, datasets and examples.
  • The API is high level: Models consume Examples and produce Predictions, not float arrays.
  • The API is uniform: all our prediction types have the same (well-typed) API, and Tribuo's classes are parameterised by the prediction type (e.g., classification uses Label, regression uses Regressor).
  • The API is reusable: it's modular and packaged into small chunks, so you only deploy what you need.
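
To make the typing point concrete, here is a toy sketch in plain Java. The types ToyLabel, ToyRegressor, ToyExample and ToyModel and the helper methods are hypothetical stand-ins invented for illustration. They are much simpler than Tribuo's real Label, Regressor, Example and Model classes (for instance, the toy model returns the output directly rather than a Prediction), and the petal-length threshold is an illustrative rule, not a trained model. The point is that a model parameterised by its output type can only return that type, so mixing up a classifier and a regressor becomes a compile-time error.

```java
import java.util.Collections;
import java.util.Map;

/** Toy sketch of an output-parameterised API; all types here are hypothetical, not Tribuo's. */
public class TypedApiSketch {
    /** Marker for anything a model can predict. */
    interface Output {}

    /** Classification output: a label name. */
    static final class ToyLabel implements Output {
        final String name;
        ToyLabel(String name) { this.name = name; }
    }

    /** Regression output: a real value. */
    static final class ToyRegressor implements Output {
        final double value;
        ToyRegressor(double value) { this.value = value; }
    }

    /** Named features; the type parameter records which kind of output they are paired with. */
    static final class ToyExample<T extends Output> {
        final Map<String, Double> features;
        ToyExample(Map<String, Double> features) { this.features = features; }
    }

    /** A model consumes examples typed by T and returns an output of the same type T. */
    interface ToyModel<T extends Output> {
        T predict(ToyExample<T> example);
    }

    /** Hypothetical helper: builds a one-feature example. */
    static <T extends Output> ToyExample<T> example(String feature, double value) {
        return new ToyExample<>(Collections.singletonMap(feature, value));
    }

    /** Illustrative threshold rule on petal length, standing in for a trained classifier. */
    static ToyModel<ToyLabel> petalLengthClassifier(double threshold) {
        return ex -> new ToyLabel(
                ex.features.getOrDefault("petalLength", 0.0) < threshold ? "Iris-setosa" : "other");
    }

    public static void main(String[] args) {
        ToyModel<ToyLabel> classifier = petalLengthClassifier(2.5);
        ToyLabel label = classifier.predict(example("petalLength", 1.4));
        System.out.println(label.name);
        // ToyRegressor wrong = classifier.predict(example("petalLength", 1.4));
        // ^ does not compile: a ToyModel<ToyLabel> can only return a ToyLabel.
    }
}
```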

Tribuo has a breadth of ML algorithms and features under the same API:

  • Classification: linear models, SVMs, trees, ensembles, deep learning
  • Regression: linear models, penalised linear regression, SVMs, trees, ensembles, deep learning
  • Clustering: K-Means
  • Anomaly Detection: SVMs

We plan to add more algorithms over time and are happy to accept community contributions; the current roadmap is on GitHub.

Tribuo makes it straightforward to load datasets, train models, and evaluate models on test data. For example, this code trains a logistic regression model and evaluates it:

var trainSet = new MutableDataset<>(new LibSVMDataSource<>(Paths.get("train-data"), new LabelFactory()));
var model    = new LogisticRegressionTrainer().train(trainSet);
var testSet  = new LibSVMDataSource<>(Paths.get("test-data"), trainSet.getOutputFactory());
var eval     = new LabelEvaluator().evaluate(model, testSet);

Getting Started

To pull Tribuo into your project use these Maven co-ordinates:

<dependency>
    <groupId>org.tribuo</groupId>
    <artifactId>tribuo-all</artifactId>
    <version>4.0.2</version>
    <type>pom</type>
</dependency>
The tribuo-all module pulls in all of Tribuo. You can later select the subset for your particular use case, as each part is also available as a separate Maven artifact.

Here's a quick example showing how to build and evaluate a classification system. It has 4 steps:

  1. Load a dataset for classifying the species of Irises from a CSV.
  2. Split that dataset into training and testing datasets.
  3. Train two types of models using different trainers.
  4. Use a model to make predictions on the test set, and evaluate its performance on the whole test set.
// Load labelled iris data
var irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
DataSource<Label> irisData =
        new CSVLoader<>(new LabelFactory()).loadDataSource(Paths.get("bezdekIris.data"),
                                     /* Output column   */ irisHeaders[4],
                                     /* Column headers  */ irisHeaders);

// Split iris data into training set (70%) and test set (30%)
var splitIrisData = new TrainTestSplitter<>(irisData,
                       /* Train fraction */ 0.7,
                             /* RNG seed */ 1L);
var trainData = new MutableDataset<>(splitIrisData.getTrain());
var testData = new MutableDataset<>(splitIrisData.getTest());

// We can train a decision tree
var cartTrainer = new CARTClassificationTrainer();
Model<Label> tree = cartTrainer.train(trainData);
// Or a logistic regression
var linearTrainer = new LogisticRegressionTrainer();
Model<Label> linear = linearTrainer.train(trainData);

// Finally we make predictions on unseen data
// Each prediction is a map from the output names (i.e. the labels) to the scores/probabilities
Prediction<Label> prediction = linear.predict(testData.getExample(0));

// Or we can evaluate the full test dataset, calculating the accuracy, F1 etc.
LabelEvaluation evaluation = new LabelEvaluator().evaluate(linear,testData);
// we can inspect the evaluation manually
double acc = evaluation.accuracy();
// which is about 0.978 (44 of 45 test examples correct)
// or print a formatted evaluation string
System.out.println(evaluation.toString());
The formatted evaluation output looks like this:

Class                           n          tp          fn          fp      recall        prec          f1
Iris-versicolor                16          16           0           1       1.000       0.941       0.970
Iris-virginica                 15          14           1           0       0.933       1.000       0.966
Iris-setosa                    14          14           0           0       1.000       1.000       1.000
Total                          45          44           1           1
Accuracy                                                                    0.978
Micro Average                                                               0.978       0.978       0.978
Macro Average                                                               0.978       0.980       0.978
Balanced Error Rate                                                         0.022
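
As a worked check of how these numbers relate, the plain-Java sketch below recomputes them from the tp/fn/fp counts in the table. It is an arithmetic illustration, not Tribuo's evaluation code. Precision is tp/(tp+fp), recall is tp/(tp+fn), F1 is their harmonic mean, the macro averages are unweighted means over the three classes, and the balanced error rate is computed here as 1 minus the macro-averaged recall, which reproduces the 0.022 in the table.

```java
import java.util.function.IntToDoubleFunction;

/** Recomputes the Iris evaluation table from its raw counts (arithmetic check, not Tribuo code). */
public class IrisMetricsCheck {
    // Counts copied from the evaluation table: versicolor, virginica, setosa.
    static final String[] CLASSES = {"Iris-versicolor", "Iris-virginica", "Iris-setosa"};
    static final int[] TP = {16, 14, 14};
    static final int[] FN = {0, 1, 0};
    static final int[] FP = {1, 0, 0};

    static double recall(int i)    { return TP[i] / (double) (TP[i] + FN[i]); }
    static double precision(int i) { return TP[i] / (double) (TP[i] + FP[i]); }

    /** Harmonic mean of precision and recall for class i. */
    static double f1(int i) {
        double p = precision(i), r = recall(i);
        return 2 * p * r / (p + r);
    }

    /** Correct predictions divided by total examples; each class has n = tp + fn examples. */
    static double accuracy() {
        int correct = 0, total = 0;
        for (int i = 0; i < TP.length; i++) {
            correct += TP[i];
            total += TP[i] + FN[i];
        }
        return correct / (double) total;
    }

    /** Unweighted (macro) mean of a per-class metric. */
    static double macro(IntToDoubleFunction metric) {
        double sum = 0;
        for (int i = 0; i < TP.length; i++) sum += metric.applyAsDouble(i);
        return sum / TP.length;
    }

    /** Balanced error rate, computed as 1 - macro-averaged recall. */
    static double balancedErrorRate() { return 1.0 - macro(IrisMetricsCheck::recall); }

    public static void main(String[] args) {
        for (int i = 0; i < CLASSES.length; i++) {
            System.out.printf("%-16s recall=%.3f prec=%.3f f1=%.3f%n",
                    CLASSES[i], recall(i), precision(i), f1(i));
        }
        System.out.printf("accuracy=%.3f%n", accuracy());
        System.out.printf("macro recall=%.3f prec=%.3f f1=%.3f%n",
                macro(IrisMetricsCheck::recall), macro(IrisMetricsCheck::precision),
                macro(IrisMetricsCheck::f1));
        System.out.printf("balanced error rate=%.3f%n", balancedErrorRate());
    }
}
```

The micro averages pool the counts before dividing: 44 true positives against 1 false positive and 1 false negative give 44/45 for both precision and recall, hence 0.978 for all three micro figures.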

To learn more about this example, take a look at our Classification Tutorial using the same Iris dataset.
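
The split step above uses a train fraction of 0.7 and a fixed RNG seed so that the split is reproducible. The sketch below shows the general idea in plain Java: shuffle with a seeded Random, then cut at the train fraction. It is an illustration only; TrainTestSplitter's actual shuffling and rounding may differ, and the split and indices methods are hypothetical helpers. With the 150 rows of the Iris data, a 70/30 split leaves 45 test examples, which matches the Total n of 45 in the evaluation table.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Generic seeded train/test split (illustrative; not TrainTestSplitter's implementation). */
public class SeededSplitSketch {
    /** Returns [train, test]; the same seed always yields the same split. */
    static <T> List<List<T>> split(List<T> data, double trainFraction, long seed) {
        if (trainFraction <= 0.0 || trainFraction >= 1.0) {
            throw new IllegalArgumentException("trainFraction must be in (0, 1)");
        }
        List<T> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, new Random(seed)); // deterministic for a fixed seed
        int trainSize = (int) Math.round(shuffled.size() * trainFraction);
        List<List<T>> result = new ArrayList<>();
        result.add(new ArrayList<>(shuffled.subList(0, trainSize)));
        result.add(new ArrayList<>(shuffled.subList(trainSize, shuffled.size())));
        return result;
    }

    /** Row indices 0..n-1, standing in for the rows of a dataset. */
    static List<Integer> indices(int n) {
        List<Integer> rows = new ArrayList<>(n);
        for (int i = 0; i < n; i++) rows.add(i);
        return rows;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(indices(150), 0.7, 1L);
        System.out.println(parts.get(0).size() + " train / " + parts.get(1).size() + " test");
    }
}
```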

Documentation Overview

The Features List gives an overview of what you can do with Tribuo and of the algorithms it supports, both natively and through interfaces to popular third-party libraries. The best way to understand Tribuo is to read through Tribuo's Architecture document, which covers basic definitions, data flow, the library structure, configuration (including options and provenance), data loading, transformations, details about examples, and the obfuscation features available to help mask your input features. The Package Structure overview describes how the packages in Tribuo are organized around the machine learning tasks that each one supports. These packages are grouped into modules so that users of Tribuo can depend only on the pieces they need. Be sure to read the Security Considerations for using Tribuo, which set out what is expected of its users. For odds and ends and general questions, the FAQ is the place to look. For details on all the classes and packages, consult Tribuo's JavaDoc.

Tutorials

We have tutorial notebooks for Classification, Clustering, Regression, Anomaly Detection, and the configuration system in the tutorials folder. These use the IJava Jupyter notebook kernel and work with Java 10+. It should be straightforward to convert the code in the tutorials back to Java 8 by replacing the var keyword with the appropriate types.
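
The conversion is mechanical: each var becomes the type the compiler inferred for it. The plain-Java sketch below writes the same small computation both ways; the class and method names are hypothetical, and because the file contains the var version it needs Java 10+ to compile.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** The same code with var (Java 10+) and with explicit types (Java 8 style). */
public class VarBackport {
    /** Tutorial style: local variable types are inferred with var. */
    static List<String> featureNamesWithVar() {
        var headers = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
        var names = new ArrayList<String>(Arrays.asList(headers));
        names.remove("species"); // drop the output column, keeping only feature names
        return names;
    }

    /** Java 8 style: each var replaced by the type the compiler inferred for it. */
    static List<String> featureNamesJava8() {
        String[] headers = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
        ArrayList<String> names = new ArrayList<String>(Arrays.asList(headers));
        names.remove("species");
        return names;
    }

    public static void main(String[] args) {
        System.out.println(featureNamesWithVar().equals(featureNamesJava8()));
    }
}
```

Note that var infers the concrete type (here ArrayList<String>), so the literal Java 8 equivalent uses that type, though declaring the variable as List<String> also works. The same rule applies to the Tribuo examples: in the Getting Started code, for instance, trainData would be declared as MutableDataset<Label>.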

Configuration and Provenance

The trainers in Tribuo are fully configurable via the OLCUT configuration system. This allows a user to define a trainer in an XML (or JSON or EDN) file once and repeatably build models with exactly the same parameters. There are example configurations for each of the supplied Trainers in the config folder of each package. Models are serializable using Java serialization, as are the datasets themselves, and the configuration used is stored with any model. All models and evaluations include a serializable provenance object which records when the model or evaluation was created, what data was used, any transformations applied to the data, the hyperparameters of the trainer, and, for evaluations, what model was used. This information can be extracted into JSON or serialized directly using Java serialization. For production deployments, this provenance information can be redacted and replaced with a hash to provide model tracking through an external system. Read more about Configuration, Options, and Provenance.
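
To illustrate the idea of replacing provenance with a hash, the plain-Java sketch below reduces a provenance record to a SHA-256 digest that an external tracking system could store as a model identifier. It is a generic sketch, not Tribuo's redaction mechanism. The provenance JSON string and the sha256Hex helper are hypothetical, and Tribuo's real provenance records are much richer.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/** Generic sketch: replace a provenance record by its SHA-256 digest (not Tribuo's mechanism). */
public class ProvenanceHashSketch {
    /** Hex-encoded SHA-256 of the UTF-8 bytes of text. */
    static String sha256Hex(String text) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-256");
            byte[] digest = md.digest(text.getBytes(StandardCharsets.UTF_8));
            StringBuilder sb = new StringBuilder(digest.length * 2);
            for (byte b : digest) sb.append(String.format("%02x", b));
            return sb.toString();
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to support SHA-256, so this should not happen.
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        // Hypothetical provenance JSON, for illustration only.
        String provenanceJson = "{\"trainer\":\"LogisticRegressionTrainer\",\"dataset\":\"train-data\"}";
        System.out.println("model-id: " + sha256Hex(provenanceJson));
    }
}
```

Identical provenance always yields the same digest, while any change to the recorded data, transformations, or hyperparameters yields a different one, which is what makes the hash usable as a tracking key without exposing the details.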

Platform Support & Requirements

Tribuo runs on Java 8+, and we test on LTS versions of Java along with the latest release. Tribuo itself is a Java library and is supported on all Java platforms; however, some of our interfaces require native code, and those are supported only on platforms where the native library is available. We test on x86_64 architectures on Windows 10, macOS, and Linux (RHEL/OL/CentOS 7+), as these are supported platforms for the native libraries we interface with. If you're interested in another platform and wish to use one of the native library interfaces (ONNX Runtime, TensorFlow, or XGBoost), we recommend reaching out to the developers of those libraries.
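
Before relying on one of the native interfaces, you can check from the standard os.name and os.arch system properties whether the JVM is on one of the tested x86_64 platforms. The sketch below is a hypothetical helper, and its string matching is an assumption: it accepts both "amd64" and "x86_64" as the architecture name, and it is coarser than the tested list above (it does not check the Windows or Linux distribution version).

```java
import java.util.Locale;

/** Coarse check against the tested native platforms (x86_64 on Windows, macOS, Linux). */
public class NativePlatformCheck {
    static boolean isTestedNativePlatform(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String arch = osArch.toLowerCase(Locale.ROOT);
        // JVMs commonly report x86_64 as "amd64" (Windows, Linux) or "x86_64" (macOS).
        boolean isX86_64 = arch.equals("amd64") || arch.equals("x86_64");
        boolean isTestedOs = os.startsWith("windows") || os.startsWith("mac") || os.startsWith("linux");
        return isX86_64 && isTestedOs;
    }

    public static void main(String[] args) {
        String name = System.getProperty("os.name");
        String arch = System.getProperty("os.arch");
        System.out.println(name + "/" + arch + " is a tested native platform: "
                + isTestedNativePlatform(name, arch));
    }
}
```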