Clustering Tutorial

This guide will show how to use Tribuo’s clustering models to find clusters in a toy dataset drawn from a mixture of Gaussians. We'll look at Tribuo's K-Means implementation and also discuss how evaluation works for clustering tasks.


We'll load in some jars and import a few packages.

In [1]:
%jars ./tribuo-clustering-kmeans-4.0.2-jar-with-dependencies.jar
In [2]:
import org.tribuo.*;
import org.tribuo.util.Util;
import org.tribuo.clustering.*;
import org.tribuo.clustering.evaluation.*;
import org.tribuo.clustering.example.ClusteringDataGenerator;
import org.tribuo.clustering.kmeans.*;
import org.tribuo.clustering.kmeans.KMeansTrainer.Distance;
In [3]:
var eval = new ClusteringEvaluator();


Tribuo's clustering package comes with a simple data generator that emits data sampled from a mixture of five 2-dimensional Gaussians (the means and covariances are fixed). The generator also provides the ground-truth cluster IDs, so it is useful for demos like this one. You can also use any of the standard data loaders to pull in clustering data.

Because clustering conforms to the same Trainer and Model interfaces used throughout the rest of Tribuo, training a clustering algorithm doesn't directly expose the cluster assignments of the training points. To recover the assignments we need to call model.predict(trainData).
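Conceptually, predicting a cluster for a point with a fitted K-Means model amounts to finding the centroid nearest to it. The following plain-Java sketch shows that assignment under Euclidean distance. The NearestCentroid class and its methods are hypothetical names for illustration, not Tribuo's internal code, and the ids it returns are 0-based array indices.

```java
import java.util.Arrays;

/** Hypothetical helper illustrating how a fitted K-Means model assigns points. */
class NearestCentroid {

    /** Squared Euclidean distance; the square root is unnecessary for comparisons. */
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    /** Returns the index of the centroid closest to the point. */
    static int assign(double[] point, double[][] centroids) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double d = squaredDistance(point, centroids[c]);
            if (d < bestDist) {
                bestDist = d;
                best = c;
            }
        }
        return best;
    }

    /** Assigns every row of the data matrix, analogous to predicting over a whole dataset. */
    static int[] assignAll(double[][] data, double[][] centroids) {
        int[] ids = new int[data.length];
        for (int i = 0; i < data.length; i++) {
            ids[i] = assign(data[i], centroids);
        }
        return ids;
    }

    public static void main(String[] args) {
        double[][] centroids = {{0.0, 0.0}, {5.0, 5.0}, {10.0, 0.0}};
        double[][] points = {{0.3, -0.2}, {4.6, 5.4}, {9.8, 0.1}, {2.0, 2.0}};
        // Prints [0, 1, 2, 0]
        System.out.println(Arrays.toString(assignAll(points, centroids)));
    }
}
```

In Tribuo itself you don't write this loop; calling model.predict on a dataset returns a prediction for each example.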

We're going to sample two datasets using different seeds: one for fitting the cluster centroids, and one for measuring clustering performance.

In [4]:
var data = ClusteringDataGenerator.gaussianClusters(500, 1L);
var test = ClusteringDataGenerator.gaussianClusters(500, 2L);

The data generator uses the following Gaussians, written as N(mean, covariance):

  1. N([ 0.0,0.0], [[1.0,0.0],[0.0,1.0]])
  2. N([ 5.0,5.0], [[1.0,0.0],[0.0,1.0]])
  3. N([ 2.5,2.5], [[1.0,0.5],[0.5,1.0]])
  4. N([10.0,0.0], [[0.1,0.0],[0.0,0.1]])
  5. N([-1.0,0.0], [[1.0,0.0],[0.0,0.1]])
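To make the covariance notation concrete, here is a minimal sketch of drawing points from one of these 2-dimensional Gaussians: factor the covariance as LL^T (a Cholesky decomposition) and map a pair of standard normal draws z to mean + Lz. This is an illustrative assumption about how such data can be generated; ClusteringDataGenerator's actual implementation may differ, and the Gaussian2D class is hypothetical.

```java
import java.util.Random;

/** Hypothetical sampler for a 2-dimensional Gaussian N(mean, cov). */
class Gaussian2D {
    final double[] mean;
    final double l11, l21, l22; // lower-triangular Cholesky factor L of cov

    Gaussian2D(double[] mean, double[][] cov) {
        this.mean = mean.clone();
        // Closed-form Cholesky decomposition of a 2x2 symmetric positive-definite matrix.
        this.l11 = Math.sqrt(cov[0][0]);
        this.l21 = cov[1][0] / l11;
        this.l22 = Math.sqrt(cov[1][1] - l21 * l21);
    }

    /** Draws one sample as mean + L * z, where z is a standard normal 2-vector. */
    double[] sample(Random rng) {
        double z1 = rng.nextGaussian();
        double z2 = rng.nextGaussian();
        return new double[] {
            mean[0] + l11 * z1,
            mean[1] + l21 * z1 + l22 * z2
        };
    }

    public static void main(String[] args) {
        // Gaussian 3 from the list above, whose components are correlated.
        var g = new Gaussian2D(new double[] {2.5, 2.5}, new double[][] {{1.0, 0.5}, {0.5, 1.0}});
        var rng = new Random(1L);
        int n = 100_000;
        double sx = 0, sy = 0, sxy = 0;
        for (int i = 0; i < n; i++) {
            double[] p = g.sample(rng);
            sx += p[0];
            sy += p[1];
            sxy += (p[0] - 2.5) * (p[1] - 2.5);
        }
        // The empirical mean should be near (2.5, 2.5) and the covariance near 0.5.
        System.out.printf("mean=(%.3f, %.3f) cov_xy=%.3f%n", sx / n, sy / n, sxy / n);
    }
}
```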

Model Training

We'll first fit a K-Means model using 5 centroids, a maximum of 10 iterations, the Euclidean distance, a single computation thread, and a random seed of 1.

In [5]:
var trainer = new KMeansTrainer(5,10,Distance.EUCLIDEAN,1,1);
var startTime = System.currentTimeMillis();
var model = trainer.train(data);
var endTime = System.currentTimeMillis();
System.out.println("Training with 5 clusters took " + Util.formatDuration(startTime,endTime));
Training with 5 clusters took (00:00:00:076)
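For readers unfamiliar with the algorithm, the sketch below shows a plain Lloyd's-style K-Means loop with the same knobs as the trainer above (number of centroids, maximum iterations, Euclidean distance, random seed). It is a single-threaded illustration under our own assumptions, notably initialising centroids from randomly chosen data points, and is not a reproduction of Tribuo's KMeansTrainer; the SimpleKMeans class is hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical minimal Lloyd's algorithm for K-Means with Euclidean distance. */
class SimpleKMeans {

    /** Index of the centroid with the smallest squared Euclidean distance to p. */
    static int nearest(double[] p, double[][] centroids) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double dist = 0.0;
            for (int d = 0; d < p.length; d++) {
                double diff = p[d] - centroids[c][d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    /** Fits k centroids, stopping early once no point changes cluster. */
    static double[][] fit(double[][] data, int k, int maxIterations, long seed) {
        int dim = data[0].length;
        Random rng = new Random(seed);
        // Assumption: each centroid starts at a randomly chosen data point.
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = data[rng.nextInt(data.length)].clone();
        }
        int[] assignment = new int[data.length];
        for (int iter = 0; iter < maxIterations; iter++) {
            // Assignment (expectation) step: attach each point to its nearest centroid.
            boolean changed = iter == 0;
            for (int i = 0; i < data.length; i++) {
                int best = nearest(data[i], centroids);
                if (best != assignment[i]) {
                    changed = true;
                }
                assignment[i] = best;
            }
            if (!changed) {
                break; // converged: no point switched cluster
            }
            // Update (maximisation) step: move each centroid to the mean of its points.
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            for (int i = 0; i < data.length; i++) {
                counts[assignment[i]]++;
                for (int d = 0; d < dim; d++) {
                    sums[assignment[i]][d] += data[i][d];
                }
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] > 0) { // an empty cluster keeps its previous centroid
                    for (int d = 0; d < dim; d++) {
                        centroids[c][d] = sums[c][d] / counts[c];
                    }
                }
            }
        }
        return centroids;
    }

    public static void main(String[] args) {
        double[][] data = {{0, 0}, {0, 1}, {1, 0}, {10, 10}, {10, 11}, {11, 10}};
        for (double[] centroid : fit(data, 2, 10, 1L)) {
            System.out.println(Arrays.toString(centroid));
        }
    }
}
```

Because each centroid is the mean of its assigned points, the result is only as good as the spherical-cluster assumption baked into the squared Euclidean distance.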

We can inspect the centroids by querying the model.

In [6]:
var centroids = model.getCentroidVectors();
// Print each learned centroid.
for (var centroid : centroids) {
    System.out.println(centroid);
}

These centroids line up pretty well with the means of the Gaussians. The predicted cluster ids correspond to the true ids as follows (because the task is unsupervised, the id numbers themselves are irrelevant; what matters is how the two sets of ids correspond):

Predicted id   True id
1              5
2              3
3              1
4              2
5              4
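One simple way to produce a correspondence like this is to count, for each predicted cluster, which true cluster most of its points came from. Below is a minimal sketch of that majority-vote matching; ClusterMatching is a hypothetical helper, not part of Tribuo, and unlike a one-to-one method such as the Hungarian algorithm it can map two predicted clusters to the same true cluster.

```java
import java.util.Arrays;

/** Hypothetical helper that maps each predicted cluster id to its majority true id. */
class ClusterMatching {

    /**
     * @param predicted predicted cluster ids in [0, numPredicted)
     * @param truth     ground-truth cluster ids in [0, numTrue)
     * @return for each predicted id, the true id carried by most of its points,
     *         or -1 if the predicted cluster is empty
     */
    static int[] majorityMapping(int[] predicted, int[] truth, int numPredicted, int numTrue) {
        // Contingency table: counts[p][t] = number of points with predicted id p and true id t.
        int[][] counts = new int[numPredicted][numTrue];
        for (int i = 0; i < predicted.length; i++) {
            counts[predicted[i]][truth[i]]++;
        }
        int[] mapping = new int[numPredicted];
        for (int p = 0; p < numPredicted; p++) {
            int best = -1;
            int bestCount = 0;
            for (int t = 0; t < numTrue; t++) {
                if (counts[p][t] > bestCount) {
                    bestCount = counts[p][t];
                    best = t;
                }
            }
            mapping[p] = best;
        }
        return mapping;
    }

    public static void main(String[] args) {
        int[] predicted = {0, 0, 0, 1, 1, 2, 2, 2};
        int[] truth     = {2, 2, 1, 0, 0, 1, 1, 1};
        // Prints [2, 0, 1]
        System.out.println(Arrays.toString(majorityMapping(predicted, truth, 3, 3)));
    }
}
```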

The first centroid is a bit far out, as its first coordinate should be -1.0 rather than -1.7, and there is a little wobble in the rest. Still, the fit is pretty good, considering that K-Means assumes spherical Gaussians while our data generator uses a separate covariance matrix for each Gaussian.

Model Evaluation

Tribuo uses mutual information, in normalized and adjusted forms, to measure the agreement between two clusterings. Mutual information depends only on which points are grouped together, not on the id numbers given to the groups, so it is unaffected by relabelling the centroids, which changes the ids without changing the clustering itself. We're going to compare against the ground-truth cluster labels from the data generator.
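To see why mutual information sidesteps the labelling problem, here is a minimal plain-Java sketch of normalized mutual information computed from a contingency table. It normalises by the arithmetic mean of the two entropies, which is one common convention and an assumption on our part; Tribuo's ClusteringEvaluator may normalise differently, and the adjusted MI it also reports is omitted here. The Nmi class is hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch of normalized mutual information between two labelings. */
class Nmi {

    /** Shannon entropy (natural log) of a labeling given its per-label counts. */
    static double entropy(int[] counts, int n) {
        double h = 0.0;
        for (int c : counts) {
            if (c > 0) {
                double p = (double) c / n;
                h -= p * Math.log(p);
            }
        }
        return h;
    }

    /** Labels must be non-negative ints; the ids need not match between labelings. */
    static double nmi(int[] a, int[] b) {
        int n = a.length;
        int ka = Arrays.stream(a).max().getAsInt() + 1;
        int kb = Arrays.stream(b).max().getAsInt() + 1;
        int[][] joint = new int[ka][kb];
        int[] countA = new int[ka];
        int[] countB = new int[kb];
        for (int i = 0; i < n; i++) {
            joint[a[i]][b[i]]++;
            countA[a[i]]++;
            countB[b[i]]++;
        }
        // I(A;B) = sum over (i,j) of p(i,j) * log(p(i,j) / (p(i) * p(j)))
        double mi = 0.0;
        for (int i = 0; i < ka; i++) {
            for (int j = 0; j < kb; j++) {
                if (joint[i][j] > 0) {
                    double pij = (double) joint[i][j] / n;
                    mi += pij * Math.log(pij * n * n / ((double) countA[i] * countB[j]));
                }
            }
        }
        double hA = entropy(countA, n);
        double hB = entropy(countB, n);
        if (hA + hB == 0.0) {
            return 1.0; // both labelings put every point in a single cluster
        }
        return 2.0 * mi / (hA + hB);
    }

    public static void main(String[] args) {
        int[] truth = {0, 0, 1, 1, 2, 2};
        int[] relabelled = {2, 2, 0, 0, 1, 1}; // same clustering, different ids
        System.out.printf("NMI = %.4f%n", nmi(truth, relabelled)); // prints NMI = 1.0000
    }
}
```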

First for the training data:

In [7]:
var trainEvaluation = eval.evaluate(model,data);
Clustering Evaluation
Normalized MI = 0.8128096132028937
Adjusted MI = 0.8113314999600718

Then for the unseen test data:

In [8]:
var testEvaluation = eval.evaluate(model,test);
Clustering Evaluation
Normalized MI = 0.8154291916732408
Adjusted MI = 0.8139169342020222

As expected, the clustering agrees well with the ground-truth labels, and the scores on the unseen test data are essentially the same as on the training data. K-Means (of the kind implemented in Tribuo) is similar to a Gaussian mixture with spherical Gaussians, whereas our data generator uses Gaussians with different, and in some cases non-spherical, covariance matrices, so the match won't be perfect.
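The connection can be stated a little more precisely. K-Means minimises the within-cluster sum of squared Euclidean distances. In a Gaussian mixture whose components all share the same spherical covariance sigma^2 I, the soft responsibilities gamma_ik harden into nearest-centroid assignments as sigma^2 goes to 0, which recovers the K-Means assignment step. This is a standard textbook relationship rather than a statement about Tribuo's implementation:

```latex
\min_{\mu_1,\dots,\mu_K} \sum_{i=1}^{N} \min_{k} \lVert x_i - \mu_k \rVert_2^2
\qquad
\gamma_{ik} = \frac{\pi_k \exp\!\left(-\lVert x_i - \mu_k \rVert_2^2 / 2\sigma^2\right)}
                   {\sum_{j} \pi_j \exp\!\left(-\lVert x_i - \mu_j \rVert_2^2 / 2\sigma^2\right)}
\;\xrightarrow{\;\sigma^2 \to 0\;}\;
\mathbf{1}\!\left[k = \arg\min_{j} \lVert x_i - \mu_j \rVert_2^2\right]
```

In the generator above, Gaussians 3 and 5 are not spherical, and Gaussian 4 is spherical but has a much smaller variance than the others, so none of them fit the shared spherical covariance assumption exactly.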


Tribuo's K-Means supports multi-threading of both the expectation and maximisation steps of the algorithm (i.e., the assignment of points to centroids, and the computation of the new centroids). We'll run the same experiment as before, with both 5 centroids and 20 centroids, using 4 threads, though this time we'll train on 2000 points.
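The assignment step parallelises naturally because each point's nearest centroid can be found independently of every other point. The sketch below splits the points into contiguous chunks and processes them on a fixed pool of threads from java.util.concurrent. It illustrates the idea only; the ParallelAssign class and its chunking strategy are our own assumptions, not how KMeansTrainer actually distributes its work.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/** Hypothetical parallel K-Means assignment step using a fixed thread pool. */
class ParallelAssign {

    /** Index of the centroid with the smallest squared Euclidean distance to p. */
    static int nearest(double[] p, double[][] centroids) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double dist = 0.0;
            for (int d = 0; d < p.length; d++) {
                double diff = p[d] - centroids[c][d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    /** Assigns each point to its nearest centroid, splitting the points across numThreads tasks. */
    static int[] assign(double[][] data, double[][] centroids, int numThreads) {
        int[] ids = new int[data.length];
        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        try {
            int chunk = (data.length + numThreads - 1) / numThreads;
            List<Future<?>> futures = new ArrayList<>();
            for (int t = 0; t < numThreads; t++) {
                int start = t * chunk;
                int end = Math.min(data.length, start + chunk);
                // Tasks write to disjoint slices of ids, so no locking is required.
                futures.add(pool.submit(() -> {
                    for (int i = start; i < end; i++) {
                        ids[i] = nearest(data[i], centroids);
                    }
                }));
            }
            for (Future<?> f : futures) {
                f.get(); // waits for the task and makes its writes visible to this thread
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while assigning points", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("An assignment task failed", e.getCause());
        } finally {
            pool.shutdown();
        }
        return ids;
    }

    public static void main(String[] args) {
        double[][] data = {{0.3, -0.2}, {4.6, 5.4}, {9.8, 0.1}, {2.0, 2.0}, {-1.2, 0.1}};
        double[][] centroids = {{0.0, 0.0}, {5.0, 5.0}, {10.0, 0.0}};
        // Prints [0, 1, 2, 0, 0]
        System.out.println(Arrays.toString(assign(data, centroids, 4)));
    }
}
```

A real trainer would likely reuse one pool across iterations rather than creating it per call; the centroid update can be parallelised in a similar way by having each task accumulate partial sums and counts that are merged afterwards.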

In [9]:
var mtData = ClusteringDataGenerator.gaussianClusters(2000, 1L);
var mtTrainer = new KMeansTrainer(5,10,Distance.EUCLIDEAN,4,1);
var mtStartTime = System.currentTimeMillis();
var mtModel = mtTrainer.train(mtData);
var mtEndTime = System.currentTimeMillis();
System.out.println("Training with 5 clusters on 4 threads took " + Util.formatDuration(mtStartTime,mtEndTime));
Training with 5 clusters on 4 threads took (00:00:00:061)

Now with 20 centroids:

In [10]:
var overTrainer = new KMeansTrainer(20,10,Distance.EUCLIDEAN,4,1);
var overStartTime = System.currentTimeMillis();
var overModel = overTrainer.train(mtData);
var overEndTime = System.currentTimeMillis();
System.out.println("Training with 20 clusters on 4 threads took " + Util.formatDuration(overStartTime,overEndTime));
Training with 20 clusters on 4 threads took (00:00:00:054)

We can evaluate the two models as before, using our ClusteringEvaluator. First with 5 centroids:

In [11]:
var mtTestEvaluation = eval.evaluate(mtModel,test);
Clustering Evaluation
Normalized MI = 0.8104463467727057
Adjusted MI = 0.8088941747451207

Then with 20:

In [12]:
var overTestEvaluation = eval.evaluate(overModel,test);
Clustering Evaluation
Normalized MI = 0.8647317143685641
Adjusted MI = 0.860327445295668

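We see that the multi-threaded versions run in less time than the single-threaded trainer, despite using 4 times as much training data, although at a scale of a few tens of milliseconds these timings are noisy, and the first training run in the JVM also includes class-loading and warm-up costs. The 20-centroid model achieves a higher NMI on the test data than the 5-centroid model (0.865 vs. 0.810), despite being overparameterised relative to the 5 Gaussians that generated the data. This is common in clustering tasks, where it's hard to balance the fit of the model against its complexity. We'll look at adding more performance metrics in future releases so users can diagnose such issues.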


We looked at clustering with Tribuo's K-Means implementation, compared the single-threaded and multi-threaded versions, and then examined the performance metrics available when ground-truth clusterings exist.

We plan to further expand Tribuo's clustering functionality to incorporate other algorithms in the future. If you want to help, or have specific algorithmic requirements, file an issue on our GitHub page.