Feature Selection

When working with tabular machine learning datasets it's quite common to collect as many columns as possible and to extract many features from those columns. This can lead to very large feature spaces, which can cause issues when fitting certain kinds of ML models, as the models may rely on irrelevant or noisy features, leading to poor test-time performance. Large feature spaces also increase training times and make it difficult to build interpretable models, as there are too many features to reason over. For these reasons the field of feature selection developed; it aims to find the subset of features relevant to a given prediction task from a large set of candidate features.

Tribuo v4.3 introduces feature selection algorithms for classification problems which score feature quality (called relevancy) and minimize redundant information in the feature set, using measures from information theory. It adds the core interfaces and classes for working with selected feature sets, along with classification-specific implementations. In the future we may add feature selection algorithms for the other prediction types Tribuo supports.

This tutorial will cover applying feature selection algorithms to a dataset, showing how reducing the feature space affects classification performance in terms of training time, accuracy and model size.

Setup

As usual we'll add some jars to the classpath and import some classes from Tribuo and the JDK. We're going to use the feature selection and SGD classification jars.

In [1]:
%jars tribuo-classification-sgd-4.3.0-jar-with-dependencies.jar
%jars tribuo-classification-fs-4.3.0-jar-with-dependencies.jar
In [2]:
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;

import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import org.tribuo.*;
import org.tribuo.dataset.SelectedFeatureDataset;
import org.tribuo.datasource.IDXDataSource;
import org.tribuo.classification.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.fs.*;
import org.tribuo.classification.sgd.fm.*;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.transform.*;
import org.tribuo.transform.transformations.LinearScalingTransformation;
import org.tribuo.util.Util;

We'll also need some data to work with, so we'll load in the MNIST train and test sets. We'll use Tribuo's built-in IDXDataSource to read them, the same way as in the configuration tutorial. If you've already downloaded MNIST then you can skip this step.

First download the training data:

wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz; wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz

Then the test data:

wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz; wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

Tribuo's IDX loader natively reads gzipped files so you don't need to unzip them. Tribuo doesn't natively understand the 2D pixel arrangement, so the feature names produced by the IDXDataSource are just the pixel indices 000 through 783, zero-padded to 3 digits.

In [3]:
var labelFactory = new LabelFactory();
var mnistTrainSource = new IDXDataSource<>(Paths.get("train-images-idx3-ubyte.gz"),Paths.get("train-labels-idx1-ubyte.gz"),labelFactory);
var mnistTestSource = new IDXDataSource<>(Paths.get("t10k-images-idx3-ubyte.gz"),Paths.get("t10k-labels-idx1-ubyte.gz"),labelFactory);
var mnistTrain = new MutableDataset<>(mnistTrainSource);
var mnistTest = new MutableDataset<>(mnistTestSource);
System.out.println(String.format("Training data size = %d, number of features = %d, number of classes = %d",mnistTrain.size(),mnistTrain.getFeatureMap().size(),mnistTrain.getOutputInfo().size()));
System.out.println(String.format("Testing data size = %d, number of features = %d, number of classes = %d",mnistTest.size(),mnistTest.getFeatureMap().size(),mnistTest.getOutputInfo().size()));
Training data size = 60000, number of features = 717, number of classes = 10
Testing data size = 10000, number of features = 668, number of classes = 10

Building a baseline

First we'll train a model using all the features. We're going to use a small factorization machine as it's a high quality predictor, though still fairly time consuming to train. We're also going to wrap its trainer in a TransformTrainer, as the factorization machine is sensitive to feature ranges and so is much happier if all the features are scaled into the range [0,1].

In [4]:
var fmTrainer = new FMClassificationTrainer(new LogMulticlass(),  // Loss function
                                            new AdaGrad(0.1,0.1), // Gradient optimiser
                                            3,                    // Number of training epochs
                                            30000,                // Logging interval
                                            Trainer.DEFAULT_SEED, // RNG seed
                                            5,                    // Factor size
                                            0.1                   // Factor initialisation variance
                                            );
var transformations = new TransformationMap(List.of(new LinearScalingTransformation(0,1)));
var trainer = new TransformTrainer(fmTrainer,transformations);

var fmStartTime = System.currentTimeMillis();
var fmModel = trainer.train(mnistTrain);
var fmEndTime = System.currentTimeMillis();
System.out.println("Training factorization machine on 783 features took " + Util.formatDuration(fmStartTime,fmEndTime));
Training factorization machine on 783 features took (00:00:31:395)

We can measure the accuracy of the full feature set:

In [5]:
var fmEvaluation = labelFactory.getEvaluator().evaluate(fmModel,mnistTest);
fmEvaluation.toString();
Out[5]:
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         962          18          35       0.982       0.965       0.973
1                           1,135       1,118          17          11       0.985       0.990       0.988
2                           1,032         992          40          63       0.961       0.940       0.951
3                           1,010         969          41          48       0.959       0.953       0.956
4                             982         937          45          41       0.954       0.958       0.956
5                             892         861          31          46       0.965       0.949       0.957
6                             958         920          38          24       0.960       0.975       0.967
7                           1,028         980          48          34       0.953       0.966       0.960
8                             974         919          55          38       0.944       0.960       0.952
9                           1,009         954          55          48       0.945       0.952       0.949
Total                      10,000       9,612         388         388
Accuracy                                                                    0.961
Micro Average                                                               0.961       0.961       0.961
Macro Average                                                               0.961       0.961       0.961
Balanced Error Rate                                                         0.039

So we get 96% accuracy, which is pretty good for a simple non-convolutional model on MNIST, but it took a little while to train, and the model depends on all 717 features, so it needs (717 * 5) * 10 parameters for the feature embeddings and 717 * 10 parameters for the linear part of the model. Let's try to reduce the model complexity a little bit by selecting only the most important features.
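Before we do, here's a quick back-of-the-envelope check on those parameter counts (plain arithmetic, not Tribuo API calls; the variable names are illustrative):

int numFeatures = 717;  // active features in the MNIST training set
int factorSize = 5;     // the factor size we passed to FMClassificationTrainer
int numClasses = 10;    // one set of parameters per digit class
int embeddingParams = numFeatures * factorSize * numClasses;  // 35,850
int linearParams = numFeatures * numClasses;                  // 7,170
System.out.println("Parameters (excluding biases) = " + (embeddingParams + linearParams));  // prints 43020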

Simple feature selection

We'll start by selecting features based on their mutual information with the class label, which measures how predictive each feature value is of the label on its own. We're going to use at most 100 features for the rest of the experiments, so we'll tell each feature selection algorithm to stop after picking 100 features. Each of the algorithms we'll discuss in this tutorial needs the data to be discretised, so we'll split each feature into 5 equal-width bins.
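For reference, MIM (Mutual Information Maximisation) scores each feature X_k by its mutual information with the label Y, estimated over the discretised values, and keeps the highest scoring features:

I(X_k; Y) = \sum_{x}\sum_{y} p(x,y) \log \frac{p(x,y)}{p(x)\,p(y)}, \qquad J_{\mathrm{MIM}}(X_k) = I(X_k; Y)

We use the MIM algorithm to select features: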

In [6]:
var mim = new MIM(100,5);

var mimSelectStartTime = System.currentTimeMillis();
var mimSet = mim.select(mnistTrain);
var mimSelectEndTime = System.currentTimeMillis();
System.out.println("Selecting the top 100 features with MIM took " + Util.formatDuration(mimSelectStartTime,mimSelectEndTime));
Selecting the top 100 features with MIM took (00:00:02:861)

That was pretty quick, but then MIM only performs a single calculation per feature. The select(Dataset<Label>) call returns a SelectedFeatureSet object which contains the feature names, their scores, a flag indicating whether the set is ordered (all the current implementations produce ordered sets), and the provenance of the selection procedure. The provenance can be serialized to a file or a byte array in the same way that other Tribuo objects can be.
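We can peek at the top of the set directly (a minimal sketch: featureScores() and isOrdered() are assumed to be the companion accessors to featureNames(), which we use later in this tutorial; check the SelectedFeatureSet javadoc if your version differs):

// Inspect whether the set is ordered and look at the top scoring feature.
System.out.println("Ordered: " + mimSet.isOrdered());
System.out.println("Top feature: " + mimSet.featureNames().get(0) + ", score = " + mimSet.featureScores().get(0));

Let's briefly look at the provenance recorded in the feature set: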

In [7]:
var setProvenance = mimSet.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(setProvenance));
SelectedFeatureSet(
	class-name = org.tribuo.SelectedFeatureSet
	dataset = MutableDataset(
			class-name = org.tribuo.MutableDataset
			datasource = IDXDataSource(
					class-name = org.tribuo.datasource.IDXDataSource
					outputPath = /local/ExternalRepositories/tribuo/tutorials/train-labels-idx1-ubyte.gz
					outputFactory = LabelFactory(
							class-name = org.tribuo.classification.LabelFactory
						)
					featuresPath = /local/ExternalRepositories/tribuo/tutorials/train-images-idx3-ubyte.gz
					features-file-modified-time = 2000-07-21T14:20:24-04:00
					output-resource-hash = 3552534A0A558BBED6AED32B30C495CCA23D567EC52CAC8BE1A0730E8010255C
					datasource-creation-time = 2022-10-07T13:13:41.469063236-04:00
					output-file-modified-time = 2000-07-21T14:20:27-04:00
					idx-feature-type = UBYTE
					features-resource-hash = 440FCABF73CC546FA21475E81EA370265605F56BE210A4024D2CA8F203523609
					host-short-name = DataSource
				)
			transformations = List[]
			is-sequence = false
			is-dense = false
			num-examples = 60000
			num-features = 717
			num-outputs = 10
			tribuo-version = 4.3.0
		)
	feature-selector = MIM(
			class-name = org.tribuo.classification.fs.MIM
			numBins = 5
			k = 100
			host-short-name = FeatureSelector
		)
	tribuo-version = 4.3.0
)

We can see the feature selection algorithm is recorded, along with the number of features to select (k = 100) and the number of discretisation bins (numBins = 5). We can also see the information about the dataset that the features were selected from, including the usual details like the location on disk, hashes and timestamps. This provenance is recorded in the SelectedFeatureDataset used to train the models, and is in turn captured in the ModelProvenance of each model trained on those datasets.

Now let's actually look at the selected features:

In [8]:
System.out.println("MIM feature set: " + mimSet.featureNames());
MIM feature set: [378, 406, 350, 434, 461, 433, 409, 377, 462, 568, 489, 405, 596, 542, 569, 437, 597, 373, 401, 428, 155, 541, 436, 400, 351, 567, 429, 381, 540, 345, 514, 543, 154, 515, 379, 488, 464, 539, 460, 156, 318, 456, 346, 570, 372, 374, 290, 595, 317, 457, 625, 408, 516, 323, 513, 490, 487, 375, 376, 512, 402, 153, 427, 455, 517, 430, 407, 511, 263, 484, 347, 626, 523, 486, 291, 485, 483, 458, 598, 354, 656, 349, 435, 624, 344, 322, 183, 463, 459, 399, 326, 432, 382, 465, 550, 655, 496, 571, 551, 657]

Most of the features are roughly in the middle of the 28x28 pixel grid in MNIST. This seems pretty sensible; there's little information in the pixels around the edges as they aren't set very often.
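Since each feature name is just a zero-padded pixel index, we can map the selected names back onto the 28x28 grid with a little integer arithmetic (a small sketch):

// Map the top five selected feature names back to (row, col) pixel positions.
// e.g. feature 378 is row 13, col 14, right in the centre of the image.
for (var name : mimSet.featureNames().subList(0,5)) {
    int idx = Integer.parseInt(name);
    System.out.println(name + " -> row " + (idx / 28) + ", col " + (idx % 28));
}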

To actually use the feature set we build a SelectedFeatureDataset, which copies an existing dataset (in this case the MutableDataset containing the MNIST training set) and removes all the unselected features. It also has an optional parameter which controls how many of the features in the selected feature set are used, making it easier to optimize over the size of the feature set returned by a selection algorithm, as sketched below.
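For example, to use just the top 50 features from the MIM set, something like the following should work (the three-argument constructor is an assumption based on the description above; check the SelectedFeatureDataset javadoc for the exact overload):

// Build a dataset containing only the top 50 features from the (ordered) MIM set.
var top50Data = new SelectedFeatureDataset<>(mnistTrain, mimSet, 50);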

So now we'll build a smaller dataset containing only the top 100 features, and see how that affects the factorization machine.

In [9]:
var mimData = new SelectedFeatureDataset<>(mnistTrain, mimSet);
var mimStartTime = System.currentTimeMillis();
var mimModel = trainer.train(mimData);
var mimEndTime = System.currentTimeMillis();
System.out.println("Training factorization machine on 100 features took " + Util.formatDuration(mimStartTime,mimEndTime));
Training factorization machine on 100 features took (00:00:12:582)

That took roughly a third of the time of training the full model, and the resulting model is roughly a seventh of the size, but we'd expect some reduction in accuracy. So let's check that:

In [10]:
var mimEvaluation = labelFactory.getEvaluator().evaluate(mimModel,mnistTest);
mimEvaluation.toString();
Out[10]:
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         951          29          49       0.970       0.951       0.961
1                           1,135       1,119          16          40       0.986       0.965       0.976
2                           1,032         943          89          62       0.914       0.938       0.926
3                           1,010         884         126          76       0.875       0.921       0.897
4                             982         886          96         211       0.902       0.808       0.852
5                             892         805          87         158       0.902       0.836       0.868
6                             958         907          51          64       0.947       0.934       0.940
7                           1,028         952          76          49       0.926       0.951       0.938
8                             974         883          91          54       0.907       0.942       0.924
9                           1,009         800         209         107       0.793       0.882       0.835
Total                      10,000       9,130         870         870
Accuracy                                                                    0.913
Micro Average                                                               0.913       0.913       0.913
Macro Average                                                               0.912       0.913       0.912
Balanced Error Rate                                                         0.088

The smaller model gets roughly 91% accuracy, and has a much lower F1 for the classes "4" and "9". Those digits can look quite similar, and when you throw away many of the pixels you might miss the crucial ones which distinguish those digits.

The MIM algorithm can be quite prone to this, as it ignores redundancy and complementarity when selecting features. Redundancy is when two features provide the same information; only one of them is necessary to convey that information to the classifier. Complementarity is when two features combine to provide more information than the sum of what each provides individually. This can occur when the information in one feature is conditional on the value of another.

So let's look at a more sophisticated algorithm, JMI (Joint Mutual Information). As before, we build a JMI object and ask it to select 100 features using 5 bins. The JMI implementation in Tribuo is multi-threaded, so we'll also tell it to use 4 compute threads.
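For reference, the JMI criterion scores each candidate feature X_k by its joint mutual information with the label Y when paired with each already selected feature X_j, which is what rewards complementarity and penalises redundancy:

J_{\mathrm{JMI}}(X_k) = \sum_{X_j \in S} I(X_k, X_j; Y)

where S is the set of features selected so far.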

In [11]:
var jmi = new JMI(100,5,4);

var jmiSelectStartTime = System.currentTimeMillis();
var jmiSet = jmi.select(mnistTrain);
var jmiSelectEndTime = System.currentTimeMillis();
System.out.println("Selecting the top 100 features with JMI took " + Util.formatDuration(jmiSelectStartTime,jmiSelectEndTime));
System.out.println("JMI feature set: " + jmiSet.featureNames());
Selecting the top 100 features with JMI took (00:01:33:192)
JMI feature set: [378, 461, 409, 568, 350, 434, 542, 406, 489, 596, 401, 381, 433, 377, 569, 462, 437, 514, 405, 155, 428, 597, 436, 373, 515, 351, 541, 543, 429, 460, 154, 488, 625, 400, 464, 567, 374, 379, 570, 375, 345, 540, 487, 456, 376, 346, 408, 490, 457, 318, 156, 516, 539, 290, 513, 459, 372, 595, 153, 486, 402, 323, 354, 347, 430, 626, 517, 458, 317, 432, 326, 407, 512, 427, 656, 349, 485, 404, 455, 263, 624, 353, 523, 598, 484, 403, 463, 571, 382, 511, 322, 291, 183, 435, 655, 544, 431, 483, 465, 410]
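Before comparing model performance, we can quantify how much the two feature sets overlap (a quick sketch):

// Count the features that appear in both the MIM and JMI selections.
var overlap = new java.util.HashSet<>(mimSet.featureNames());
overlap.retainAll(jmiSet.featureNames());
System.out.println("Features in both sets: " + overlap.size() + " / 100");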

The first feature selected by JMI is the same as the one chosen by MIM, because both algorithms start by picking the feature with the highest mutual information with the label. After that point the feature sets diverge: JMI selects features which combine well with the already selected ones, maximising relevancy and complementarity while minimising redundancy, whereas MIM selects features based solely on relevancy. Let's see how that changes the performance of the factorization machine:

In [12]:
var jmiData = new SelectedFeatureDataset<>(mnistTrain, jmiSet);
var jmiStartTime = System.currentTimeMillis();
var jmiModel = trainer.train(jmiData);
var jmiEndTime = System.currentTimeMillis();
System.out.println("Training factorization machine on 100 features took " + Util.formatDuration(jmiStartTime,jmiEndTime));
Training factorization machine on 100 features took (00:00:12:409)

As before, training on 100 features takes about a third of the time of the full model, and it still has the same size benefit. Let's look at the accuracy:

In [13]:
var jmiEvaluation = labelFactory.getEvaluator().evaluate(jmiModel,mnistTest);
jmiEvaluation.toString();
Out[13]:
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         950          30          57       0.969       0.943       0.956
1                           1,135       1,116          19          42       0.983       0.964       0.973
2                           1,032         948          84          78       0.919       0.924       0.921
3                           1,010         934          76         107       0.925       0.897       0.911
4                             982         808         174         109       0.823       0.881       0.851
5                             892         765         127          76       0.858       0.910       0.883
6                             958         884          74          50       0.923       0.946       0.934
7                           1,028         972          56          76       0.946       0.927       0.936
8                             974         881          93          57       0.905       0.939       0.922
9                           1,009         891         118         199       0.883       0.817       0.849
Total                      10,000       9,149         851         851
Accuracy                                                                    0.915
Micro Average                                                               0.915       0.915       0.915
Macro Average                                                               0.913       0.915       0.914
Balanced Error Rate                                                         0.087

Performance is a little better overall, and it's about 1 point of F1 better at predicting "9"s. We could further tune the performance by optimizing over the number of features, looking for a model which retains most of the accuracy while shrinking in size (which usually makes deployment simpler and prediction faster), as sketched below.
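Such a sweep, reusing the size-limited SelectedFeatureDataset constructor assumed earlier, might look like this:

// Train and evaluate on increasingly large prefixes of the (ordered) JMI feature set.
for (int k : new int[]{25, 50, 75, 100}) {
    var data = new SelectedFeatureDataset<>(mnistTrain, jmiSet, k);
    var model = trainer.train(data);
    var eval = labelFactory.getEvaluator().evaluate(model, mnistTest);
    System.out.println("k = " + k + ", accuracy = " + eval.accuracy());
}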

We've been using Tribuo's sparse feature handling implicitly during this exploration, as each time we evaluated we passed in the full test dataset. Tribuo's models discard any features that weren't present at training time, so each model automatically ignored the features that weren't among the 100 it was trained on. This makes feature selection experiments much easier than in systems where you'd need to explicitly subselect the test dataset and repack it into test-time arrays or tensors.

Feature selection for data exploration

We've looked at using feature selection to improve training time and reduce model complexity, but it can also be used without building a model at all, as part of a data exploration procedure. We've selected the 100 most important features for making predictions on MNIST, which for MNIST itself isn't particularly useful as we collect all the pixel values at once. However, if we're working on tabular problems where each column of data has a particular cost of collection or storage, or has different privacy or governance requirements, then discovering which features we actually need to collect can make our data pipelines much faster and cheaper to implement. We can also try to discover why these features are relevant for predictions using knowledge of the problem domain, which might provide some insight into the prediction problem itself.

Conclusions

We discussed the process of feature selection, and ran a few algorithms on MNIST showing the tradeoffs between feature set size, accuracy and training speed. We also looked at the provenance captured by feature selection algorithms, showing how it records all the relevant information in the feature set object.