Multi-Label Classification Tutorial

This tutorial shows how to use Tribuo's MultiLabel package to perform multi-label classification tasks. Multi-label classification is the task of assigning a set of labels from a specific label domain to a given example, as opposed to multi-class classification, which assigns a single label to a given example.

Tribuo provides linear model and factorization machine algorithms for native multi-label prediction, along with ensemble methods that either predict each label independently or as part of a classifier chain, using any of Tribuo's classification algorithms as the base learners. The linear models, the factorization machines and the IndependentMultiLabelTrainer all use the Binary Relevance approach to multi-label prediction, where each label is predicted independently. The ClassifierChainTrainer and CCEnsembleTrainer use classifier chains, which incorporate label structure into the prediction. In this tutorial we'll cover loading multi-label data, making predictions using several binary relevance classifiers along with some classifier chains, and finally evaluating the multi-label models using Tribuo's multi-label evaluation package.

Setup

First you'll need a copy of the multi-label yeast dataset (we'll download it from the LibSVM dataset repo):

wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel/yeast_train.svm.bz2
wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel/yeast_test.svm.bz2

Then you should extract it using your preferred method. On macOS and Linux you can use bunzip2, and on Windows there are several packages which can extract bz2 files (e.g., 7-zip).

This dataset has 14 labels which represent different functional groups, and the task is to predict the functional groups a gene belongs to based on micro-array expression measurements. Fortunately we don't need a PhD in Genetics to use this dataset as a benchmark, though obviously domain knowledge would be critical if we wanted to actually deploy any model based on this data.

Now we'll load in the necessary jars and import some packages from the JDK and Tribuo.

In [1]:
%jars ./tribuo-multilabel-sgd-4.3.0-jar-with-dependencies.jar
%jars ./tribuo-classification-experiments-4.3.0-jar-with-dependencies.jar
In [2]:
import java.nio.file.Paths;
In [3]:
import org.tribuo.*;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.CARTClassificationTrainer;
import org.tribuo.classification.dtree.impurity.*;
import org.tribuo.datasource.*;
import org.tribuo.math.optimisers.*;
import org.tribuo.multilabel.*;
import org.tribuo.multilabel.baseline.*;
import org.tribuo.multilabel.ensemble.*;
import org.tribuo.multilabel.evaluation.*;
import org.tribuo.multilabel.sgd.linear.*;
import org.tribuo.multilabel.sgd.objectives.*;
import org.tribuo.util.Util;

Loading the data

There are two main forms for multi-label data in columnar representations. Either the dataset stores the labels in a single column using some delimiter (e.g., "first_label,third_label"), resulting in a sparse representation of the labels, or each label is stored in its own column with a flag representing if that label is present (e.g., "TRUE" or "1"), resulting in a dense representation of the labels. Tribuo can load both formats, though currently the MultiLabelFactory only supports comma separated labels when parsing inputs directly from a String. When multi-label values are processed through a RowProcessor, the factory receives a List<String> and processes each separate label appropriately.
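
To make the two layouts concrete, here is a minimal plain-Java sketch that turns each one into the same set of label names. It does not use Tribuo's parsers; the class name `LabelFormats`, its methods, and the accepted flag values are assumptions chosen for illustration.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;

// Illustrative only: a hypothetical helper, not part of Tribuo.
public class LabelFormats {

    // Sparse form: a single column holding the active labels, separated by a delimiter.
    public static Set<String> parseSparse(String column, String delimiter) {
        Set<String> labels = new LinkedHashSet<>();
        for (String token : column.split(Pattern.quote(delimiter))) {
            String trimmed = token.trim();
            if (!trimmed.isEmpty()) {
                labels.add(trimmed);
            }
        }
        return labels;
    }

    // Dense form: one column per label, each holding a presence flag.
    // Assumption: "TRUE" (any case) and "1" mean the label is present.
    public static Set<String> parseDense(List<String> labelNames, List<String> flags) {
        if (labelNames.size() != flags.size()) {
            throw new IllegalArgumentException("Expected one flag per label column");
        }
        Set<String> labels = new LinkedHashSet<>();
        for (int i = 0; i < labelNames.size(); i++) {
            String flag = flags.get(i).trim();
            if (flag.equalsIgnoreCase("TRUE") || flag.equals("1")) {
                labels.add(labelNames.get(i));
            }
        }
        return labels;
    }

    public static void main(String[] args) {
        Set<String> sparse = parseSparse("first_label,third_label", ",");
        Set<String> dense = parseDense(
                List.of("first_label", "second_label", "third_label"),
                List.of("1", "0", "TRUE"));
        System.out.println(sparse);
        System.out.println(dense);
        // Both layouts describe the same label set.
        System.out.println(sparse.equals(dense)); // prints true
    }
}
```

Either way the result is a sparse set containing only the active labels, which is also how the MultiLabel output described later stores them.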

The yeast dataset we downloaded is stored in libsvm format which uses a sparse representation of the labels, so we'll use Tribuo's LibSVMDataSource to load it in, and process the outputs through a MultiLabelFactory.

In [4]:
var factory = new MultiLabelFactory();
var trainSource = new LibSVMDataSource<>(Paths.get(".","yeast_train.svm"),factory);
var testSource = new LibSVMDataSource<>(Paths.get(".","yeast_test.svm"),factory,trainSource.isZeroIndexed(),trainSource.getMaxFeatureID());
var train = new MutableDataset<>(trainSource);
var test = new MutableDataset<>(testSource);
System.out.println(String.format("Training data size = %d, number of features = %d, number of classes = %d",train.size(),train.getFeatureMap().size(),train.getOutputInfo().size()));
System.out.println(String.format("Testing data size = %d, number of features = %d, number of classes = %d",test.size(),test.getFeatureMap().size(),test.getOutputInfo().size()));
Training data size = 1500, number of features = 103, number of classes = 14
Testing data size = 917, number of features = 103, number of classes = 14

In Tribuo we represent a multi-label task using the org.tribuo.multilabel.MultiLabel output, which internally uses a set of org.tribuo.classification.Label objects to store the individual labels. This means that unlike most Tribuo prediction type packages, tribuo-multilabel-core depends on another output core package, tribuo-classification-core. MultiLabel is a sparse representation of the labels: only the Labels which are active are stored in the MultiLabel object.

We can inspect the first two outputs from the training dataset to see this:

In [5]:
System.out.println("First output = " + train.getExample(0).getOutput());
System.out.println("Second output = " + train.getExample(1).getOutput());
First output = (LabelSet={2,3})
Second output = (LabelSet={11,12,6,7})

This first example is tagged with labels 2 & 3, and the second one is tagged with 6, 7, 11 and 12. Unfortunately the LibSVM format we loaded in uses numbers rather than names for the labels, but if there are more descriptive names present when the data is loaded in then those would be used as the label names.

The task of multi-label prediction

Multi-label problems can be approached in several different ways. A common approach is to treat each label as an independent function of the input features. This leads to the binary relevance approach, where each label is independent of the others and multi-label classification can be thought of as a set of standard binary classification problems. This approach scales well, but if there is underlying structure in the label space (e.g., the label "human" implies the label "animal", but the label "animal" does not imply "human", so they are not independent), then it ignores useful information from the training data and may underperform more complicated approaches.

Another popular way to convert a multi-label problem into a standard classification problem is via a label powerset, where each unique combination of the individual labels is treated as a single label in a large multi-class classification problem. While this allows the learning algorithm to fully capture any interactions between the labels, the size of the label powerset is exponential in the number of labels, which rapidly makes this approach intractable as the number of labels increases, though it can be useful in small label spaces. Tribuo currently focuses on binary relevance and other approaches which don't require exponential computation, though we're happy to discuss incorporating label powerset methods if people have need for them.
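
The two problem transformations described above can be sketched in a few lines of plain Java. This is an illustration under stated assumptions rather than Tribuo code: the class `ProblemTransformations` and its methods are hypothetical, and the label sets are the two training outputs printed earlier.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Set;

// Illustrative only: hypothetical helpers, not part of Tribuo.
public class ProblemTransformations {

    // Binary relevance: one true/false target per example for a single label,
    // so each label becomes its own binary classification problem.
    public static boolean[] binaryTargets(List<Set<String>> labelSets, String label) {
        boolean[] targets = new boolean[labelSets.size()];
        for (int i = 0; i < labelSets.size(); i++) {
            targets[i] = labelSets.get(i).contains(label);
        }
        return targets;
    }

    // Label powerset: every subset of the labels is a potential class,
    // so there are at most 2^numLabels classes.
    public static long powersetSize(int numLabels) {
        if (numLabels < 0 || numLabels > 62) {
            throw new IllegalArgumentException("numLabels must be between 0 and 62");
        }
        return 1L << numLabels;
    }

    public static void main(String[] args) {
        List<Set<String>> outputs = List.of(Set.of("2", "3"), Set.of("6", "7", "11", "12"));
        System.out.println(Arrays.toString(binaryTargets(outputs, "2"))); // prints [true, false]
        System.out.println(powersetSize(14)); // prints 16384
    }
}
```

With the 14 labels of the yeast dataset, binary relevance needs 14 binary problems, whereas a label powerset has up to 16,384 possible classes, far more than the 1,500 training examples could cover.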

Training Binary Relevance models

Now we'll train a few different binary relevance models (i.e., independent predictions of each label). First we'll use Tribuo's multi-label LinearSGDModel which natively makes multi-label predictions, then we'll wrap a binary classification decision tree into a multi-label predictor using IndependentMultiLabelTrainer and IndependentMultiLabelModel. Note: Tribuo has three classes called LinearSGDModel, one each for Label, MultiLabel, and Regressor, so the LinearSGDModel used in this tutorial is org.tribuo.multilabel.sgd.linear.LinearSGDModel, and the one used in the multi-class classification tutorials is org.tribuo.classification.sgd.linear.LinearSGDModel.

Tribuo's multi-label SGD package supports two different objective functions, Binary Cross-Entropy and Hinge loss. The BCE loss produces probabilistic outputs thresholded at 0.5, whereas the hinge loss produces scores thresholded at 0. As Tribuo usually produces scores for each possible label, these thresholds are used to determine when a particular label is present in the MultiLabel object representing the predicted label set. As you may have seen in other tutorials, Tribuo uses stochastic gradient descent to fit its linear models, so we'll define a gradient optimizer along with the loss function.

In [6]:
var linTrainer = new LinearSGDTrainer(new BinaryCrossEntropy(),new AdaGrad(0.1,0.1),5,1000,1,Trainer.DEFAULT_SEED);
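
The thresholding step described above can be sketched in plain Java. This is a minimal illustration, not Tribuo's implementation: the class `ScoreThresholds`, the raw scores, and the choice of a strict "greater than" comparison are all assumptions.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Illustrative only: a hypothetical helper, not part of Tribuo.
public class ScoreThresholds {

    // Logistic function, mapping a raw score to a value in (0, 1).
    public static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Keeps every label whose score is strictly above the threshold
    // (the strict comparison is an assumption of this sketch).
    public static Set<String> threshold(Map<String, Double> scores, double threshold) {
        Set<String> labels = new TreeSet<>();
        for (Map.Entry<String, Double> e : scores.entrySet()) {
            if (e.getValue() > threshold) {
                labels.add(e.getKey());
            }
        }
        return labels;
    }

    public static void main(String[] args) {
        // Hypothetical raw linear scores for three labels.
        Map<String, Double> raw = Map.of("animal", 2.0, "human", -1.0, "plant", 0.5);

        // Hinge-style: threshold the raw scores at 0.
        System.out.println(threshold(raw, 0.0)); // prints [animal, plant]

        // BCE-style: squash to probabilities, then threshold at 0.5.
        Map<String, Double> probs = new LinkedHashMap<>();
        raw.forEach((label, score) -> probs.put(label, sigmoid(score)));
        System.out.println(threshold(probs, 0.5)); // prints [animal, plant]
    }
}
```

Because the sigmoid of a score exceeds 0.5 exactly when the score exceeds 0, both conventions select the same labels from the same raw scores; the objectives differ in how the model is fitted and in whether the reported values can be read as probabilities.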

We train the model the same way we train the rest of Tribuo's models.

In [7]:
var linStartTime = System.currentTimeMillis();
var linModel = linTrainer.train(train);
var linEndTime = System.currentTimeMillis();
System.out.println();
System.out.println("Linear model training took " + Util.formatDuration(linStartTime,linEndTime));
Linear model training took (00:00:00:192)

Tribuo doesn't have a native implementation of a multi-label decision tree, but it does have a multi-class decision tree, which we can convert into a multi-label predictor using IndependentMultiLabelTrainer. Now let's train a model using a decision tree to predict each label independently. First we define the binary classification trainer, then we'll use IndependentMultiLabelTrainer to wrap that Trainer<Label> and convert it into a Trainer<MultiLabel>, before training as usual.

In [8]:
Trainer<Label> treeTrainer = new CARTClassificationTrainer(6,10,0.0f,1.0f,false,new Entropy(),1L);
Trainer<MultiLabel> dtTrainer = new IndependentMultiLabelTrainer(treeTrainer);
var dtStartTime = System.currentTimeMillis();
var dtModel = dtTrainer.train(train);
var dtEndTime = System.currentTimeMillis();
System.out.println();
System.out.println("Tree model training took " + Util.formatDuration(dtStartTime,dtEndTime));
Tree model training took (00:00:03:188)

We've now got two different models, so let's measure their performance.

Evaluating multi-label problems

Multi-label problems have many evaluation options available, as many standard classification evaluation measures like accuracy, precision, recall and F1 can be applied at the label level, and there are also many set level metrics such as the Jaccard Index which can be used to compare the predicted label set against the ground truth one. In Tribuo we have access to most of the metrics available for multi-class classification problems, and v4.2 began adding set level metrics as well. If there are useful metrics that aren't implemented in Tribuo, please raise an issue on Tribuo's GitHub page.
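
As an example of a set level metric, the sketch below computes the Jaccard index between a predicted and a true label set, and averages it over examples. It uses the common definition (intersection size divided by union size) in plain Java; the class `JaccardSketch` is hypothetical, the example predictions are made up, and treating two empty sets as a perfect match is an assumption of this sketch that may differ from Tribuo's behaviour.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Illustrative only: a hypothetical helper, not part of Tribuo.
public class JaccardSketch {

    // |predicted ∩ truth| / |predicted ∪ truth| for one example.
    public static double jaccard(Set<String> predicted, Set<String> truth) {
        if (predicted.isEmpty() && truth.isEmpty()) {
            return 1.0; // assumption: two empty label sets agree perfectly
        }
        Set<String> intersection = new HashSet<>(predicted);
        intersection.retainAll(truth);
        Set<String> union = new HashSet<>(predicted);
        union.addAll(truth);
        return (double) intersection.size() / union.size();
    }

    // Mean of the per-example Jaccard indices.
    public static double meanJaccard(List<Set<String>> predictions, List<Set<String>> truths) {
        if (predictions.isEmpty() || predictions.size() != truths.size()) {
            throw new IllegalArgumentException("Need equally sized, non-empty lists");
        }
        double sum = 0.0;
        for (int i = 0; i < predictions.size(); i++) {
            sum += jaccard(predictions.get(i), truths.get(i));
        }
        return sum / predictions.size();
    }

    public static void main(String[] args) {
        // True label sets from the first two training examples; predictions are made up.
        List<Set<String>> truths = List.of(Set.of("2", "3"), Set.of("6", "7", "11", "12"));
        List<Set<String>> predictions = List.of(Set.of("2", "3", "12"), Set.of("11", "12"));
        System.out.println(jaccard(predictions.get(0), truths.get(0))); // 2 shared labels out of 3
        System.out.println(jaccard(predictions.get(1), truths.get(1))); // 2 shared labels out of 4
        System.out.println(meanJaccard(predictions, truths));
    }
}
```

Unlike the per-label metrics, this score rewards a prediction for getting the whole set right, so one spurious label and one missing label are both penalised within the same example.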

If you want to use the predicted scores for each of the labels separately (e.g., to analyse the model's performance) then as usual the Map<String,MultiLabel> available from Prediction.getOutputScores() has the full distribution. This map behaves slightly counter-intuitively, as each value is a MultiLabel object containing a single Label, and the key is the output of Label.toString(). This allows the labels to be inspected individually, but it is a little uncomfortable if you're used to working with a multi-label specific API. However it maintains conformity across all of Tribuo's different prediction APIs, both for predictions and evaluations, which makes it easier to incorporate lots of ML models into a larger system.

We use the same evaluation paradigm as other Tribuo prediction tasks: first we construct an Evaluator, then we feed it a model and some test data to produce an Evaluation which contains the appropriate performance metrics.

In [9]:
var eval = new MultiLabelEvaluator();

First we'll look at the linear model:

In [10]:
var linTStartTime = System.currentTimeMillis();
var linEval = eval.evaluate(linModel,test);
var linTEndTime = System.currentTimeMillis();
System.out.println();
System.out.println("Linear model evaluation took " + Util.formatDuration(linTStartTime,linTEndTime));
System.out.println(linEval);
Linear model evaluation took (00:00:00:063)
Class                           n          tp          fn          fp      recall        prec          f1
(LabelSet={12})               683         677           6         230       0.991       0.746       0.852
(LabelSet={13})                13           0          13           0       0.000       0.000       0.000
(LabelSet={0})                286         131         155          45       0.458       0.744       0.567
(LabelSet={1})                393         162         231         130       0.412       0.555       0.473
(LabelSet={2})                385         231         154          95       0.600       0.709       0.650
(LabelSet={3})                330         169         161          79       0.512       0.681       0.585
(LabelSet={4})                281         106         175          35       0.377       0.752       0.502
(LabelSet={5})                219          26         193          14       0.119       0.650       0.201
(LabelSet={6})                167           0         167           0       0.000       0.000       0.000
(LabelSet={7})                191           0         191           0       0.000       0.000       0.000
(LabelSet={8})                 80           0          80           0       0.000       0.000       0.000
(LabelSet={9})                 92           0          92           0       0.000       0.000       0.000
(LabelSet={10})                91           0          91           0       0.000       0.000       0.000
(LabelSet={11})               688         684           4         226       0.994       0.752       0.856
Total                       3,899       2,186       1,713         854
Accuracy                                                                    0.561
Micro Average                                                               0.561       0.719       0.630
Macro Average                                                               0.319       0.399       0.335
Balanced Error Rate                                                         0.681
Jaccard Score                                                               0.497

Next, the decision tree:

In [11]:
var dtTStartTime = System.currentTimeMillis();
var dtEval = eval.evaluate(dtModel,test);
var dtTEndTime = System.currentTimeMillis();
System.out.println();
System.out.println("Tree model evaluation took " + Util.formatDuration(dtTStartTime,dtTEndTime));
System.out.println(dtEval);
Tree model evaluation took (00:00:00:094)
Class                           n          tp          fn          fp      recall        prec          f1
(LabelSet={12})               683         607          76         201       0.889       0.751       0.814
(LabelSet={13})                13           0          13           2       0.000       0.000       0.000
(LabelSet={0})                286         111         175          98       0.388       0.531       0.448
(LabelSet={1})                393         187         206         181       0.476       0.508       0.491
(LabelSet={2})                385         251         134         193       0.652       0.565       0.606
(LabelSet={3})                330         131         199          67       0.397       0.662       0.496
(LabelSet={4})                281          92         189          41       0.327       0.692       0.444
(LabelSet={5})                219          88         131         189       0.402       0.318       0.355
(LabelSet={6})                167          30         137          41       0.180       0.423       0.252
(LabelSet={7})                191          29         162          46       0.152       0.387       0.218
(LabelSet={8})                 80           1          79          12       0.013       0.077       0.022
(LabelSet={9})                 92          14          78          35       0.152       0.286       0.199
(LabelSet={10})                91          20          71          83       0.220       0.194       0.206
(LabelSet={11})               688         633          55         202       0.920       0.758       0.831
Total                       3,899       2,194       1,705       1,391
Accuracy                                                                    0.563
Micro Average                                                               0.563       0.612       0.586
Macro Average                                                               0.369       0.439       0.384
Balanced Error Rate                                                         0.631
Jaccard Score                                                               0.439

We can see the native multi-label linear model outperformed the wrapped decision tree in terms of Jaccard Score, though the picture is more mixed in the other metrics, and the linear model is ignoring some of the labels.
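
To see why the micro and macro averages tell different stories, the sketch below recomputes them from the per-label counts in the linear model's table. It is plain Java rather than Tribuo's evaluator; the class `AveragingSketch` is hypothetical, and scoring a label with no predictions or no positives as 0 is an assumption, though it reproduces the figures printed above.

```java
import java.util.Locale;

// Illustrative only: a hypothetical helper, not part of Tribuo.
public class AveragingSketch {

    // Per-label counts copied from the linear model's evaluation table (same row order).
    static final int[] TP = {677, 0, 131, 162, 231, 169, 106, 26, 0, 0, 0, 0, 0, 684};
    static final int[] FN = {6, 13, 155, 231, 154, 161, 175, 193, 167, 191, 80, 92, 91, 4};
    static final int[] FP = {230, 0, 45, 130, 95, 79, 35, 14, 0, 0, 0, 0, 0, 226};

    // Micro average: pool the counts over all labels, then take one ratio.
    // Pass FN as errors for recall, FP as errors for precision.
    public static double microAverage(int[] tp, int[] errors) {
        long tpSum = 0;
        long errorSum = 0;
        for (int i = 0; i < tp.length; i++) {
            tpSum += tp[i];
            errorSum += errors[i];
        }
        return tpSum + errorSum == 0 ? 0.0 : (double) tpSum / (tpSum + errorSum);
    }

    // Macro average: take the ratio per label, then average the ratios.
    // Assumption: a label with a zero denominator contributes 0.
    public static double macroAverage(int[] tp, int[] errors) {
        double sum = 0.0;
        for (int i = 0; i < tp.length; i++) {
            int denominator = tp[i] + errors[i];
            sum += denominator == 0 ? 0.0 : (double) tp[i] / denominator;
        }
        return sum / tp.length;
    }

    public static void main(String[] args) {
        System.out.println(String.format(Locale.ROOT, "micro recall    = %.3f", microAverage(TP, FN))); // 0.561
        System.out.println(String.format(Locale.ROOT, "micro precision = %.3f", microAverage(TP, FP))); // 0.719
        System.out.println(String.format(Locale.ROOT, "macro recall    = %.3f", macroAverage(TP, FN))); // 0.319
        System.out.println(String.format(Locale.ROOT, "macro precision = %.3f", macroAverage(TP, FP))); // 0.399
    }
}
```

The micro averages are dominated by the frequent labels 11 and 12, which the linear model predicts well, while the macro averages weight every label equally and so are pulled down by the six labels the model never predicts.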

Unfortunately some of the metrics we might like to examine for regular multi-class classification aren't as easy to use in the multi-label case. For example, a multi-class confusion matrix has no direct analogue in the multi-label case, as there could be an arbitrary number of labels predicted for each output, meaning there is no notion of a label being mispredicted as another label. This means a multi-label confusion matrix is best presented as a series of binary confusion matrices, one per label. This tends to take up a lot of space, so we'll skip inspecting them in this tutorial, though they are accessible on the MultiLabelEvaluation object.
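
A per-label binary confusion matrix of the kind described above can be computed directly from the predicted and true label sets. The sketch below is plain Java with made-up predictions; the class `PerLabelConfusion` is hypothetical and does not reflect how the MultiLabelEvaluation object exposes its matrices.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Set;

// Illustrative only: a hypothetical helper, not part of Tribuo.
public class PerLabelConfusion {

    // Returns {tp, fp, fn, tn} for a single label across all examples.
    public static int[] confusion(List<Set<String>> predictions, List<Set<String>> truths, String label) {
        if (predictions.size() != truths.size()) {
            throw new IllegalArgumentException("Need one prediction per ground truth");
        }
        int tp = 0, fp = 0, fn = 0, tn = 0;
        for (int i = 0; i < predictions.size(); i++) {
            boolean predicted = predictions.get(i).contains(label);
            boolean actual = truths.get(i).contains(label);
            if (predicted && actual) {
                tp++;
            } else if (predicted) {
                fp++;
            } else if (actual) {
                fn++;
            } else {
                tn++;
            }
        }
        return new int[] {tp, fp, fn, tn};
    }

    public static void main(String[] args) {
        List<Set<String>> truths = List.of(
                Set.of("2", "3"), Set.of("6", "7", "11", "12"), Set.of("12"));
        // Made-up predictions for the three examples above.
        List<Set<String>> predictions = List.of(
                Set.of("2", "12"), Set.of("11", "12"), Set.<String>of());
        System.out.println(Arrays.toString(confusion(predictions, truths, "12"))); // prints [1, 1, 1, 0]
        System.out.println(Arrays.toString(confusion(predictions, truths, "2")));  // prints [1, 0, 0, 2]
    }
}
```

Each label yields its own four counts, which is why a dataset with 14 labels needs 14 separate 2x2 matrices rather than a single 14x14 one.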

Now let's look at a more complicated multi-label classification approach, using Classifier Chains.

Training Classifier Chains

A classifier chain is similar to a binary relevance model, except there is a sequential order to the predictions (forming a chain), and each member of the chain receives extra features in the form of the predictions of earlier members of the chain. This means that if the chain is correctly ordered according to the causal structure of the labels (which is tricky to do) then it can start with the most independent label first, and then predict each label in sequence so it can use the earlier predictions to improve predictions for each subsequent label (e.g., we could predict if the example is an "animal" first, and then when we come to predict if it's a "human" we know that humans are animals making the prediction task easier).
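
The prediction loop of a classifier chain can be sketched as follows. This is a toy illustration in plain Java, not Tribuo's ClassifierChainModel: the `BinaryModel` interface and `ChainSketch` class are hypothetical, and the two "models" are hand-written rules standing in for trained classifiers.

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Illustrative only: a hypothetical helper, not part of Tribuo.
public class ChainSketch {

    // Hypothetical stand-in for a trained binary classifier.
    public interface BinaryModel {
        boolean predict(double[] features);
    }

    // Predicts the labels in chain order, appending each prediction (as 0.0 or 1.0)
    // to the feature vector seen by the later members of the chain.
    public static Set<String> predictChain(List<String> labelOrder, List<BinaryModel> models, double[] features) {
        double[] augmented = Arrays.copyOf(features, features.length + labelOrder.size());
        Set<String> predicted = new LinkedHashSet<>();
        for (int i = 0; i < labelOrder.size(); i++) {
            // Member i sees the original features plus the i earlier predictions.
            double[] input = Arrays.copyOf(augmented, features.length + i);
            boolean present = models.get(i).predict(input);
            augmented[features.length + i] = present ? 1.0 : 0.0;
            if (present) {
                predicted.add(labelOrder.get(i));
            }
        }
        return predicted;
    }

    public static void main(String[] args) {
        // Toy rules over two made-up features, not learned models.
        BinaryModel animal = f -> f[0] > 0.5;
        // f[2] is the earlier "animal" prediction appended by the chain.
        BinaryModel human = f -> f[1] > 0.5 && f[2] == 1.0;

        List<String> order = List.of("animal", "human");
        List<BinaryModel> models = List.of(animal, human);
        System.out.println(predictChain(order, models, new double[] {0.9, 0.8})); // prints [animal, human]
        System.out.println(predictChain(order, models, new double[] {0.1, 0.8})); // prints []
    }
}
```

In the second call the "human" rule fires on the raw features but is vetoed because the earlier "animal" prediction was negative, which is the kind of label structure a binary relevance model cannot express.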

In practice we don't usually know the correct ordering of the labels as the causal structure is unknown, and if we supply the incorrect structure then we can reduce performance back to the level of the binary relevance models. Fortunately in Machine Learning we have a trick we can use when we need to deal with uncertain data, which is to randomize it many times, and take an average. So we could take many different classifier chains, each with a random label order, and then each chain votes on the labels that should be predicted. This improves statistical performance over a single chain with a random order, and over a single chain with a poorly chosen order, though it's unlikely to beat a single classifier chain with the correct label ordering (if such an ordering exists). Unfortunately the classifier chain ensemble is more expensive computationally than the single chain, which is already relatively expensive compared to a single classifier like LinearSGDModel, but the chains can be straightforwardly parallelized (and we'll add support for this to a future version of Tribuo).
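
The two ingredients of the ensemble, random label orders and voting, can be sketched in plain Java. This is an illustration only: the class `ChainEnsembleSketch` is hypothetical, the chain predictions are made up, and the strict majority rule used here is one simple way to combine votes that may differ from the combiner Tribuo uses.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

// Illustrative only: a hypothetical helper, not part of Tribuo.
public class ChainEnsembleSketch {

    // Draws a reproducible random label order for one chain.
    public static List<String> randomOrder(List<String> labels, long seed) {
        List<String> order = new ArrayList<>(labels);
        Collections.shuffle(order, new Random(seed));
        return order;
    }

    // Keeps each label predicted by more than half of the chains
    // (assumption: a strict majority vote).
    public static Set<String> vote(List<Set<String>> chainPredictions) {
        Map<String, Integer> counts = new TreeMap<>();
        for (Set<String> prediction : chainPredictions) {
            for (String label : prediction) {
                counts.merge(label, 1, Integer::sum);
            }
        }
        Set<String> result = new TreeSet<>();
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            if (e.getValue() * 2 > chainPredictions.size()) {
                result.add(e.getKey());
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> labels = List.of("0", "1", "2", "3");
        // Each ensemble member gets its own seed and therefore its own order.
        for (long seed = 0; seed < 3; seed++) {
            System.out.println("chain " + seed + " order = " + randomOrder(labels, seed));
        }

        // Made-up predictions from three chains for one example.
        List<Set<String>> votes = List.of(Set.of("2", "3"), Set.of("2"), Set.of("2", "3", "12"));
        System.out.println(vote(votes)); // prints [2, 3]
    }
}
```

Label "12" is dropped because only one of the three chains predicted it, which shows how voting can suppress the false positives of an individual chain.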

We're going to use a single classifier chain with a random order, and then an ensemble of 20 classifier chains each using random orders to see how they perform.

In [12]:
var ccTrainer = new ClassifierChainTrainer(treeTrainer,1L);
var ccEnsembleTrainer = new CCEnsembleTrainer(treeTrainer,20,1L);

First we'll train and evaluate the single chain:

In [13]:
// train the model
var ccStartTime = System.currentTimeMillis();
var ccModel = ccTrainer.train(train);
var ccEndTime = System.currentTimeMillis();
System.out.println();
System.out.println("Classifier Chain model training took " + Util.formatDuration(ccStartTime,ccEndTime));

// evaluate the model
var ccTStartTime = System.currentTimeMillis();
var ccEval = eval.evaluate(ccModel,test);
var ccTEndTime = System.currentTimeMillis();
System.out.println("Classifier Chain model evaluation took " + Util.formatDuration(ccTStartTime,ccTEndTime));
System.out.println(ccEval);
Classifier Chain model training took (00:00:02:893)
Classifier Chain model evaluation took (00:00:00:153)
Class                           n          tp          fn          fp      recall        prec          f1
(LabelSet={12})               683         616          67         203       0.902       0.752       0.820
(LabelSet={13})                13           0          13           2       0.000       0.000       0.000
(LabelSet={0})                286         159         127         172       0.556       0.480       0.515
(LabelSet={1})                393         215         178         213       0.547       0.502       0.524
(LabelSet={2})                385         251         134         193       0.652       0.565       0.606
(LabelSet={3})                330         199         131         151       0.603       0.569       0.585
(LabelSet={4})                281         112         169         124       0.399       0.475       0.433
(LabelSet={5})                219          74         145         116       0.338       0.389       0.362
(LabelSet={6})                167          39         128          48       0.234       0.448       0.307
(LabelSet={7})                191          40         151          51       0.209       0.440       0.284
(LabelSet={8})                 80           0          80          14       0.000       0.000       0.000
(LabelSet={9})                 92           7          85          30       0.076       0.189       0.109
(LabelSet={10})                91           8          83          32       0.088       0.200       0.122
(LabelSet={11})               688         615          73         192       0.894       0.762       0.823
Total                       3,899       2,335       1,564       1,541
Accuracy                                                                    0.599
Micro Average                                                               0.599       0.602       0.601
Macro Average                                                               0.393       0.412       0.392
Balanced Error Rate                                                         0.607
Jaccard Score                                                               0.473

We can see the classifier chain improved over the binary relevance model when using trees as the base learner, and took roughly the same amount of time to train and evaluate. It's still not quite up to the linear model, but let's try the chain ensemble and see how it does.

Now we'll train and evaluate the ensemble:

In [14]:
// train the model
var ccEnsembleStartTime = System.currentTimeMillis();
var ccEnsembleModel = ccEnsembleTrainer.train(train);
var ccEnsembleEndTime = System.currentTimeMillis();
System.out.println();
System.out.println("Classifier Chain Ensemble model training took " + Util.formatDuration(ccEnsembleStartTime,ccEnsembleEndTime));

// evaluate the model
var ccETStartTime = System.currentTimeMillis();
var ccEnsembleEval = eval.evaluate(ccEnsembleModel,test);
var ccETEndTime = System.currentTimeMillis();
System.out.println("Classifier Chain Ensemble model evaluation took " + Util.formatDuration(ccETStartTime,ccETEndTime));
System.out.println(ccEnsembleEval);
Classifier Chain Ensemble model training took (00:00:54:230)
Classifier Chain Ensemble model evaluation took (00:00:02:249)
Class                           n          tp          fn          fp      recall        prec          f1
(LabelSet={12})               683         629          54         216       0.921       0.744       0.823
(LabelSet={13})                13           0          13           1       0.000       0.000       0.000
(LabelSet={0})                286         112         174          64       0.392       0.636       0.485
(LabelSet={1})                393         170         223         140       0.433       0.548       0.484
(LabelSet={2})                385         254         131         146       0.660       0.635       0.647
(LabelSet={3})                330         194         136         111       0.588       0.636       0.611
(LabelSet={4})                281         112         169          45       0.399       0.713       0.511
(LabelSet={5})                219          49         170          40       0.224       0.551       0.318
(LabelSet={6})                167          14         153           6       0.084       0.700       0.150
(LabelSet={7})                191          14         177          19       0.073       0.424       0.125
(LabelSet={8})                 80           0          80           0       0.000       0.000       0.000
(LabelSet={9})                 92           0          92           4       0.000       0.000       0.000
(LabelSet={10})                91           2          89           2       0.022       0.500       0.042
(LabelSet={11})               688         640          48         205       0.930       0.757       0.835
Total                       3,899       2,190       1,709         999
Accuracy                                                                    0.562
Micro Average                                                               0.562       0.687       0.618
Macro Average                                                               0.337       0.489       0.359
Balanced Error Rate                                                         0.663
Jaccard Score                                                               0.485

As expected the classifier chain ensemble outperformed the binary relevance model and the single classifier chain in terms of Jaccard score and micro-averaged F1 when using trees as the base learner, at the cost of the greatest runtime. It did this by significantly decreasing the number of false positives, at the cost of a small increase in false negatives. We didn't quite beat the performance of the linear model in terms of Jaccard score, but in general classifier chains are a powerful multi-label approach, and we could always use the linear model as a base learner (doing so does improve the Jaccard score above 0.497). We leave the implementation of that as an exercise for the reader.

Conclusion

We looked at Tribuo's multi-label classification package, trying out several different models using different approaches to the multi-label problem, namely binary relevance models and classifier chains. We're interested in expanding Tribuo's support for multi-label problems, so if there are algorithms or metrics Tribuo is missing, head over to our GitHub page. Contributions are always welcome.