This page describes Tribuo 4.0. View the documentation for Tribuo 4.3 instead.

Configuration Tutorial

This tutorial will show how to use Tribuo's configuration and provenance systems to build models on MNIST (because we wouldn't be doing ML without an MNIST demo). We'll focus on logistic regression, show how many different trainers can be stored in the same configuration, and how the provenance system allows the configuration for a specific run to be regenerated. We'll also briefly look at Tribuo's feature transformation system and see how that integrates into configuration and provenance.

Setup

You'll need to get a copy of the MNIST dataset in the original IDX format.

First the training data:

wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz

Then the test data:

wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

Tribuo's IDX loader natively reads gzipped files so you don't need to unzip them.

It's Java, so first we load in the necessary Tribuo jars. Here we're using the classification experiments jar, along with the json interop jar to read and write the provenance information.

In [1]:
%jars ./tribuo-classification-experiments-4.0.2-jar-with-dependencies.jar
%jars ./tribuo-json-4.0.2-jar-with-dependencies.jar

Now let's import the packages we need. We'll use a few file manipulation things from Java, and then Tribuo's core packages, the transformation packages, the classification package, classification evaluation package, and then a few things that relate to the provenance system.

In [2]:
import java.nio.file.Files;
import java.nio.file.Paths;
In [3]:
import org.tribuo.*;
import org.tribuo.util.Util;
import org.tribuo.transform.*;
import org.tribuo.transform.transformations.LinearScalingTransformation;
import org.tribuo.classification.*;
import org.tribuo.classification.evaluation.*;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.DescribeConfigurable;
import com.oracle.labs.mlrg.olcut.provenance.*;
import com.oracle.labs.mlrg.olcut.provenance.primitives.*;
import com.oracle.labs.mlrg.olcut.config.json.JsonConfigFactory;

By default OLCUT's ConfigurationManager only understands XML files. The snippet below adds JSON support to all ConfigurationManagers in the running JVM. If you're using OLCUT's CLI options processing, the same support can be added on the command line by supplying --config-file-format <fully-qualified-class-name>, where the class name is, for example, com.oracle.labs.mlrg.olcut.config.json.JsonConfigFactory.

In [4]:
ConfigurationManager.addFileFormatFactory(new JsonConfigFactory())

How does configuration work?

Tribuo uses a configuration system originally built in Sun Labs and open sourced in the OLCUT library. Classes which can be configured must implement the Configurable interface, and can optionally implement a public void postConfig() method, which is used to check invariants after a class has been configured but before it's visible. Configurable classes mark which of their fields are available for configuration using the @Config annotation, which accepts three arguments: boolean mandatory, which makes the configuration system error out when the field is not configured; String description, a description of the field used as a comment and in the DescribeConfigurable system seen below; and boolean redact, which controls whether this field's value is written into configuration files or provenance objects.
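To make the mechanics concrete, here is a minimal sketch of annotation-driven configuration using only the Java standard library. It is not OLCUT's implementation: the SketchConfig annotation, SketchTrainer class and SketchConfigurator helper are all hypothetical names invented for this illustration, and the sketch only handles int and long fields.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.Map;

// Hypothetical stand-in for OLCUT's @Config annotation; not the real type.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@interface SketchConfig {
    boolean mandatory() default false;
    String description() default "";
    boolean redact() default false;
}

// Hypothetical configurable class with two annotated private fields.
class SketchTrainer {
    @SketchConfig(mandatory = true, description = "Number of passes over the data.")
    private int epochs;

    @SketchConfig(description = "Seed for the RNG.")
    private long seed = 12345L;

    // Invariant check, run after the fields are set but before the object is handed out.
    void postConfig() {
        if (epochs <= 0) {
            throw new IllegalArgumentException("epochs must be positive, found " + epochs);
        }
    }

    int getEpochs() { return epochs; }
    long getSeed() { return seed; }
}

class SketchConfigurator {
    // Copies string properties into the annotated fields, then calls postConfig().
    static SketchTrainer configure(Map<String, String> properties) {
        SketchTrainer trainer = new SketchTrainer();
        try {
            for (Field field : SketchTrainer.class.getDeclaredFields()) {
                SketchConfig annotation = field.getAnnotation(SketchConfig.class);
                if (annotation == null) {
                    continue; // not a configurable field
                }
                String value = properties.get(field.getName());
                if (value == null) {
                    if (annotation.mandatory()) {
                        throw new IllegalArgumentException("Missing mandatory field " + field.getName());
                    }
                    continue; // keep the default from the field initialiser
                }
                field.setAccessible(true);
                if (field.getType() == int.class) {
                    field.setInt(trainer, Integer.parseInt(value));
                } else if (field.getType() == long.class) {
                    field.setLong(trainer, Long.parseLong(value));
                }
            }
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Could not set a configurable field", e);
        }
        trainer.postConfig();
        return trainer;
    }
}
```

Calling SketchConfigurator.configure(Map.of("epochs", "2")) in this sketch sets epochs from the supplied property and leaves seed at its default, while an empty map is rejected because epochs is marked mandatory.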

As configuration is part of the class file rather than the public documented API (because it operates on private fields), OLCUT ships with a CLI utility for inspecting a configurable class and generating an example configuration in any supported configuration format. To use this utility from the command line you can run:

$ java -cp <path-to-jars-including-olcut-core> com.oracle.labs.mlrg.olcut.config.DescribeConfigurable -n <class-name> -o -e xml

where the -n argument denotes what class to describe, -o denotes that an example configuration should be generated, and -e gives the file format to emit the example configuration in.

You can also use the REPL to inspect a configurable class, like so:

In [5]:
var className = "org.tribuo.classification.sgd.linear.LinearSGDTrainer";
var clazz = (Class<? extends Configurable>) Class.forName(className);
Map map = DescribeConfigurable.generateFieldInfo(clazz);

var output = DescribeConfigurable.generateDescription(map);

System.out.println("Class: " + clazz.getCanonicalName() + "\n");
System.out.println(DescribeConfigurable.formatDescription(output));
Class: org.tribuo.classification.sgd.linear.LinearSGDTrainer

Field Name      Type                                         Mandatory Redact Default                                                       Description
epochs          int                                          false     false  5                                                             The number of gradient descent epochs.
loggingInterval int                                          false     false  -1                                                            Log values after this many updates.
minibatchSize   int                                          false     false  1                                                             Minibatch size in SGD.
objective       org.tribuo.classification.sgd.LabelObjective false     false  LogMulticlass                                                 The classification objective function to use.
optimiser       org.tribuo.math.StochasticGradientOptimiser  false     false  AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0) The gradient optimiser to use.
seed            long                                         false     false  12345                                                         Seed for the RNG used to shuffle elements.
shuffle         boolean                                      false     false  true                                                          Shuffle the data before each epoch. Only turn off for debugging.

And also to print out an example config file:

In [6]:
ByteArrayOutputStream writer = new ByteArrayOutputStream();
DescribeConfigurable.writeExampleConfig(writer,"json",clazz,map);
System.out.println(writer.toString("UTF-8"));
{
  "config" : {
    "components" : [ {
      "name" : "example",
      "type" : "org.tribuo.classification.sgd.linear.LinearSGDTrainer",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "seed" : "0",
        "minibatchSize" : "0",
        "shuffle" : "false",
        "epochs" : "0",
        "optimiser" : "StochasticGradientOptimiser-instance",
        "loggingInterval" : "0",
        "objective" : "LabelObjective-instance"
      }
    } ]
  }
}

At the moment using it from the REPL is poorly specified (note the lack of generic type information from DescribeConfigurable.generateFieldInfo). We'll fix that in the next OLCUT release.

Using a configuration file

We're going to read in an example configuration file, in JSON format. This configuration knows about a bunch of different trainers, and also the training and testing MNIST data sources. In the tutorials directory we supply both the JSON and XML versions of this file, and the remainder of this tutorial is completely agnostic to which one is used.

In [7]:
var configPath = Paths.get("configuration","example-config.json");
String.join("\n",Files.readAllLines(configPath));
Out[7]:
{
  "config" : {
    "components" : [ {
      "name" : "mnist-test",
      "type" : "org.tribuo.datasource.IDXDataSource",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "featuresPath" : "t10k-images-idx3-ubyte.gz",
        "outputPath" : "t10k-labels-idx1-ubyte.gz",
        "outputFactory" : "label-factory"
      }
    }, {
      "name" : "mnist-train",
      "type" : "org.tribuo.datasource.IDXDataSource",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "featuresPath" : "train-images-idx3-ubyte.gz",
        "outputPath" : "train-labels-idx1-ubyte.gz",
        "outputFactory" : "label-factory"
      }
    }, {
      "name" : "adagrad",
      "type" : "org.tribuo.math.optimisers.AdaGrad",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "epsilon" : "0.01",
        "initialLearningRate" : "0.5"
      }
    }, {
      "name" : "log",
      "type" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
      "export" : "false",
      "import" : "false"
    }, {
      "name" : "label-factory",
      "type" : "org.tribuo.classification.LabelFactory",
      "export" : "false",
      "import" : "false"
    }, {
      "name" : "gini",
      "type" : "org.tribuo.classification.dtree.impurity.GiniIndex",
      "export" : "false",
      "import" : "false"
    }, {
      "name" : "cart",
      "type" : "org.tribuo.classification.dtree.CARTClassificationTrainer",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "maxDepth" : "6",
        "impurity" : "gini",
        "seed" : "12345",
        "fractionFeaturesInSplit" : "0.5"
      }
    }, {
      "name" : "entropy",
      "type" : "org.tribuo.classification.dtree.impurity.Entropy",
      "export" : "false",
      "import" : "false"
    }, {
      "name" : "logistic",
      "type" : "org.tribuo.classification.sgd.linear.LinearSGDTrainer",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "seed" : "1",
        "minibatchSize" : "1",
        "epochs" : "2",
        "optimiser" : "adagrad",
        "objective" : "log",
        "loggingInterval" : "10000"
      }
    }, {
      "name" : "xgboost",
      "type" : "org.tribuo.classification.xgboost.XGBoostClassificationTrainer",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "numTrees" : "10",
        "maxDepth" : "4",
        "eta" : "0.5",
        "seed" : "1",
        "minChildWeight" : "1.0",
        "subsample" : "1.0",
        "nThread" : "6",
        "gamma" : "0.1"
      }
    } ]
  }
}

Now we'll make a ConfigurationManager and hand it the configuration file to load. Our configuration system also supports CLI options which can load things out of the supplied configuration files. We have examples of this in each of the simple TrainTest demo classes in each prediction backend.

In [8]:
var cm = new ConfigurationManager(configPath.toString());

First we'll load in the training and testing DataSources (as instances of IDXDataSource), pass them into two Datasets to aggregate the appropriate metadata, and we'll make the evaluator for later use.

In [9]:
DataSource<Label> mnistTrain = (DataSource<Label>) cm.lookup("mnist-train");
DataSource<Label> mnistTest = (DataSource<Label>) cm.lookup("mnist-test");
var trainData = new MutableDataset<>(mnistTrain);
var testData = new MutableDataset<>(mnistTest);
var evaluator = new LabelEvaluator();
System.out.println(String.format("Training data size = %d, number of features = %d, number of classes = %d",trainData.size(),trainData.getFeatureMap().size(),trainData.getOutputInfo().size()));
System.out.println(String.format("Testing data size = %d, number of features = %d, number of classes = %d",testData.size(),testData.getFeatureMap().size(),testData.getOutputInfo().size()));
Training data size = 60000, number of features = 717, number of classes = 10
Testing data size = 10000, number of features = 668, number of classes = 10

Loading in trainers from the configuration

Our configuration file contains a number of different trainers, so let's pull them out and take a look.

The first one we'll see is a CART decision tree, with a max tree depth of 6.

In [10]:
var cart = (Trainer<Label>) cm.lookup("cart");
cart
Out[10]:
CARTClassificationTrainer(maxDepth=6,minChildWeight=5.0,fractionFeaturesInSplit=0.5,impurity=GiniIndex,seed=12345)

Next we'll load an XGBoost trainer, using 10 trees, 6 computation threads, and some regularisation parameters. Note: Tribuo's XGBoost support relies upon the Maven Central XGBoost jar from DMLC, which contains macOS and Linux binaries. On Windows, please compile DMLC's XGBoost jar from source and rebuild Tribuo.

In [11]:
var xgb = (Trainer<Label>) cm.lookup("xgboost");
xgb
Out[11]:
XGBoostTrainer(numTrees=10,parameters{colsample_bytree=1.0, silent=1, seed=1, max_depth=4, booster=gbtree, objective=multi:softprob, lambda=1.0, eta=0.5, nthread=6, alpha=1.0, subsample=1.0, gamma=0.1, min_child_weight=1.0})

Finally we'll load in a logistic regression trainer, using AdaGrad as the gradient optimizer.

In [12]:
var logistic = (Trainer<Label>) cm.lookup("logistic");
logistic
Out[12]:
LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=0.5,epsilon=0.01,initialValue=0.0),epochs=2,minibatchSize=1,seed=1)

We can also load a list containing all the Trainer implementations in this config file. Note: by default the config system returns the same instance each time it's queried for the same named config, so the list contains references to the objects we've already loaded.

In [13]:
var trainers = (List<Trainer>) cm.lookupAll(Trainer.class);
System.out.println("Loaded " + trainers.size() + " trainers.");
Loaded 3 trainers.
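The instance reuse mentioned in the note above can be pictured as a name-keyed cache. The following is only a sketch under that assumption, using a hypothetical SketchRegistry class from the standard library rather than OLCUT's ConfigurationManager; returning null for an unknown name is this sketch's choice, not a statement about OLCUT's behaviour.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

// Hypothetical registry: builds a component the first time its name is looked up,
// then returns the cached instance on every later lookup.
class SketchRegistry {
    private final Map<String, Supplier<Object>> recipes = new HashMap<>();
    private final Map<String, Object> instances = new HashMap<>();

    // Records how to build the component with the given name.
    void register(String name, Supplier<Object> recipe) {
        recipes.put(name, recipe);
    }

    // Returns the cached instance, constructing it on first use.
    Object lookup(String name) {
        Supplier<Object> recipe = recipes.get(name);
        if (recipe == null) {
            return null; // unknown component name (a choice made by this sketch)
        }
        return instances.computeIfAbsent(name, key -> recipe.get());
    }
}
```

With a cache like this, looking up "logistic" twice yields the same object reference, which is why a list built from a later query can share elements with objects fetched earlier.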

Training the model and extracting configuration

We're going to focus on the logistic regression trainer now, so let's train a logistic regression model on our MNIST training set.

In [14]:
var lrStartTime = System.currentTimeMillis();
var lrModel = logistic.train(trainData);
var lrEndTime = System.currentTimeMillis();
System.out.println("Training logistic regression took " + Util.formatDuration(lrStartTime,lrEndTime));
Training logistic regression took (00:00:04:710)

We can inspect the trained model for its provenance, as we saw in the Classification tutorial.

The new step is extracting a configuration from that provenance. The ProvenanceUtil.extractConfiguration() call returns a List<ConfigurationData>, which is the object representation of a configuration file. We can see that it's extracted configurations for 5 objects from our single model. We'll look at those after we've written out the file.

In [15]:
var provenance = lrModel.getProvenance();
var provConfig = ProvenanceUtil.extractConfiguration(provenance);
provConfig.size()
Out[15]:
5

The ConfigurationManager is the way we can generate a configuration file from the object representation. We create a new ConfigurationManager, add the configuration we extracted from the provenance, and then write it out to a new JSON file.

In [16]:
var outputFile = "mnist-logistic-config.json";
var newCM = new ConfigurationManager();
newCM.addConfiguration(provConfig);
newCM.save(new File(outputFile),true);
String.join("\n",Files.readAllLines(Paths.get(outputFile)))
Out[16]:
{
  "config" : {
    "components" : [ {
      "name" : "idxdatasource-1",
      "type" : "org.tribuo.datasource.IDXDataSource",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "outputPath" : "/Users/apocock/Development/Tribuo/tutorials/train-labels-idx1-ubyte.gz",
        "outputFactory" : "labelfactory-4",
        "featuresPath" : "/Users/apocock/Development/Tribuo/tutorials/train-images-idx3-ubyte.gz"
      }
    }, {
      "name" : "linearsgdtrainer-0",
      "type" : "org.tribuo.classification.sgd.linear.LinearSGDTrainer",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "seed" : "1",
        "minibatchSize" : "1",
        "shuffle" : "true",
        "epochs" : "2",
        "optimiser" : "adagrad-2",
        "objective" : "logmulticlass-3",
        "loggingInterval" : "10000"
      }
    }, {
      "name" : "adagrad-2",
      "type" : "org.tribuo.math.optimisers.AdaGrad",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "epsilon" : "0.01",
        "initialLearningRate" : "0.5",
        "initialValue" : "0.0"
      }
    }, {
      "name" : "labelfactory-4",
      "type" : "org.tribuo.classification.LabelFactory",
      "export" : "false",
      "import" : "false"
    }, {
      "name" : "logmulticlass-3",
      "type" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
      "export" : "false",
      "import" : "false"
    } ]
  }
}

The five elements of the configuration are: the training data "idxdatasource-1", the logistic regression "linearsgdtrainer-0", the training log loss function "logmulticlass-3", the AdaGrad gradient optimizer "adagrad-2", and the label factory "labelfactory-4". The only unexpected part is the LabelFactory which is the factory that converts Strings into Label instances.

Rebuilding a model from its configuration

Now to reconstruct our model, we can load in the Trainer and DataSource from the new ConfigurationManager, pass the source into a Dataset, and finally call train on the new trainer supplying the new dataset.

In [17]:
var newTrainer = (Trainer<Label>) newCM.lookup("linearsgdtrainer-0");
var newSource = (DataSource<Label>) newCM.lookup("idxdatasource-1");
var newDataset = new MutableDataset<>(newSource);
var newModel = newTrainer.train(newDataset, Collections.singletonMap("reconfigured-model",new BooleanProvenance("reconfigured-model",true)));

First we'll confirm that the provenances of the old and new models aren't equal (they have different timestamps, among other differences such as the "reconfigured-model" flag we just added).

In [18]:
lrModel.getProvenance().equals(newModel.getProvenance())
Out[18]:
false

Now we'll evaluate the first model:

In [19]:
var lrEvaluator = evaluator.evaluate(lrModel,testData);
System.out.println(lrEvaluator.toString());
System.out.println(lrEvaluator.getConfusionMatrix().toString());
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         904          76          21       0.922       0.977       0.949
1                           1,135       1,072          63          18       0.944       0.983       0.964
2                           1,032         856         176          56       0.829       0.939       0.881
3                           1,010         844         166          84       0.836       0.909       0.871
4                             982         888          94          72       0.904       0.925       0.915
5                             892         751         141         143       0.842       0.840       0.841
6                             958         938          20         139       0.979       0.871       0.922
7                           1,028         963          65         133       0.937       0.879       0.907
8                             974         892          82         363       0.916       0.711       0.800
9                           1,009         801         208          62       0.794       0.928       0.856
Total                      10,000       8,909       1,091       1,091
Accuracy                                                                    0.891
Micro Average                                                               0.891       0.891       0.891
Macro Average                                                               0.890       0.896       0.890
Balanced Error Rate                                                         0.110
               0       1       2       3       4       5       6       7       8       9
0            904       0       2       3       1      20      26       4      18       2
1              0   1,072       7       3       0       2       6       2      43       0
2              3       6     856      26       5       7      39       8      80       2
3              1       0      13     844       2      64       7      14      62       3
4              0       0       7       2     888       1      22      15      20      27
5              9       1       1      27       6     751      18       7      68       4
6              3       1       2       1       1       9     938       1       2       0
7              1       5      18       6       4       1       0     963       9      21
8              1       3       6       9       9      25      20       6     892       3
9              3       2       0       7      44      14       1      76      61     801

It's about what we'd expect for a linear model on MNIST. Not SOTA, but it'll do for now.
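As a sanity check on how to read the table, the per-class columns can be recomputed from the tp, fn and fp counts. This is a minimal sketch using the standard definitions of recall, precision and F1; the SketchMetrics class is a hypothetical helper and not part of Tribuo's evaluation API.

```java
// Per-class metrics from true positive, false negative and false positive counts.
class SketchMetrics {
    // Fraction of the examples of this class that were predicted correctly.
    static double recall(int tp, int fn) {
        return tp / (double) (tp + fn);
    }

    // Fraction of the predictions of this class that were correct.
    static double precision(int tp, int fp) {
        return tp / (double) (tp + fp);
    }

    // F1 is the harmonic mean of precision and recall, written in terms of the counts.
    static double f1(int tp, int fn, int fp) {
        return (2.0 * tp) / (2.0 * tp + fn + fp);
    }

    // Accuracy is the number of correct predictions over the number of examples.
    static double accuracy(int correct, int total) {
        return correct / (double) total;
    }
}
```

For class 0 in the table above (tp = 904, fn = 76, fp = 21) these definitions give a recall of about 0.922, a precision of about 0.977 and an F1 of about 0.949, and 8,909 correct predictions out of 10,000 give the 0.891 accuracy.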

Now let's check the new model:

In [20]:
var newEvaluator = evaluator.evaluate(newModel,testData);
System.out.println(newEvaluator.toString());
System.out.println(newEvaluator.getConfusionMatrix().toString());
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         904          76          21       0.922       0.977       0.949
1                           1,135       1,072          63          18       0.944       0.983       0.964
2                           1,032         856         176          56       0.829       0.939       0.881
3                           1,010         844         166          84       0.836       0.909       0.871
4                             982         888          94          72       0.904       0.925       0.915
5                             892         751         141         143       0.842       0.840       0.841
6                             958         938          20         139       0.979       0.871       0.922
7                           1,028         963          65         133       0.937       0.879       0.907
8                             974         892          82         363       0.916       0.711       0.800
9                           1,009         801         208          62       0.794       0.928       0.856
Total                      10,000       8,909       1,091       1,091
Accuracy                                                                    0.891
Micro Average                                                               0.891       0.891       0.891
Macro Average                                                               0.890       0.896       0.890
Balanced Error Rate                                                         0.110
               0       1       2       3       4       5       6       7       8       9
0            904       0       2       3       1      20      26       4      18       2
1              0   1,072       7       3       0       2       6       2      43       0
2              3       6     856      26       5       7      39       8      80       2
3              1       0      13     844       2      64       7      14      62       3
4              0       0       7       2     888       1      22      15      20      27
5              9       1       1      27       6     751      18       7      68       4
6              3       1       2       1       1       9     938       1       2       0
7              1       5      18       6       4       1       0     963       9      21
8              1       3       6       9       9      25      20       6     892       3
9              3       2       0       7      44      14       1      76      61     801

We can see that both models perform identically. This is because our provenance system records the RNG seeds used at all points, and Tribuo is scrupulous about how and when it uses PRNGs. If you find a model reconstruction that gives a different answer (unless you're using XGBoost, which has some non-determinism beyond our control) then file an issue on our GitHub as that's a bug.
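The role of the recorded seed can be illustrated with a small standard library sketch: an RNG constructed from the same seed produces the same shuffle order every time, so any computation driven by it is repeatable. The SketchSeededShuffle class is hypothetical and uses java.util.Random purely for illustration; Tribuo's trainers may use a different generator internally.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class SketchSeededShuffle {
    // Returns the indices 0..n-1 shuffled by an RNG created from the given seed.
    static List<Integer> shuffledIndices(int n, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        // A fresh Random with the same seed always yields the same permutation.
        Collections.shuffle(indices, new Random(seed));
        return indices;
    }
}
```

Two calls with the same n and seed return equal lists, which is the property that lets a trainer rebuilt from its recorded seed visit the training examples in the same order.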

What else lives in the Provenance?

These evaluations have provenance in the same way the models do, and we can use a pretty printer in OLCUT to make it a little more human readable.

In addition to the configuration information like the gradient optimiser and RNG seed, the provenance includes run specific information like the "reconfigured-model" flag we added, along with a hash of the data, timestamps for the various data files involved, and a timestamp for the model creation and dataset creation.

In [21]:
var evalProvenance = newEvaluator.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(evalProvenance));
EvaluationProvenance(
	class-name = org.tribuo.provenance.EvaluationProvenance
	model-provenance = LinearSGDModel(
			class-name = org.tribuo.classification.sgd.linear.LinearSGDModel
			dataset = MutableDataset(
					class-name = org.tribuo.MutableDataset
					datasource = IDXDataSource(
							class-name = org.tribuo.datasource.IDXDataSource
							outputPath = /Users/apocock/Development/Tribuo/tutorials/train-labels-idx1-ubyte.gz
							outputFactory = LabelFactory(
									class-name = org.tribuo.classification.LabelFactory
								)
							featuresPath = /Users/apocock/Development/Tribuo/tutorials/train-images-idx3-ubyte.gz
							features-file-modified-time = 2000-07-21T14:20:24-04:00
							output-resource-hash = 3552534A0A558BBED6AED32B30C495CCA23D567EC52CAC8BE1A0730E8010255C
							datasource-creation-time = 2020-11-05T10:31:35.857394-05:00
							output-file-modified-time = 2000-07-21T14:20:27-04:00
							idx-feature-type = UBYTE
							features-resource-hash = 440FCABF73CC546FA21475E81EA370265605F56BE210A4024D2CA8F203523609
							host-short-name = DataSource
						)
					transformations = List[]
					is-sequence = false
					is-dense = false
					num-examples = 60000
					num-features = 717
					num-outputs = 10
					tribuo-version = 4.0.2
				)
			trainer = LinearSGDTrainer(
					class-name = org.tribuo.classification.sgd.linear.LinearSGDTrainer
					seed = 1
					minibatchSize = 1
					shuffle = true
					epochs = 2
					optimiser = AdaGrad(
							class-name = org.tribuo.math.optimisers.AdaGrad
							epsilon = 0.01
							initialLearningRate = 0.5
							initialValue = 0.0
							host-short-name = StochasticGradientOptimiser
						)
					objective = LogMulticlass(
							class-name = org.tribuo.classification.sgd.objectives.LogMulticlass
							host-short-name = LabelObjective
						)
					loggingInterval = 10000
					train-invocation-count = 0
					is-sequence = false
					host-short-name = Trainer
				)
			trained-at = 2020-11-05T10:31:41.706114-05:00
			instance-values = Map{
				reconfigured-model=true
			}
			tribuo-version = 4.0.2
		)
	dataset-provenance = MutableDataset(
			class-name = org.tribuo.MutableDataset
			datasource = IDXDataSource(
					class-name = org.tribuo.datasource.IDXDataSource
					outputPath = /Users/apocock/Development/Tribuo/tutorials/t10k-labels-idx1-ubyte.gz
					outputFactory = LabelFactory(
							class-name = org.tribuo.classification.LabelFactory
						)
					featuresPath = /Users/apocock/Development/Tribuo/tutorials/t10k-images-idx3-ubyte.gz
					features-file-modified-time = 2000-07-21T14:19:56-04:00
					output-resource-hash = F7AE60F92E00EC6DEBD23A6088C31DBD2371ECA3FFA0DEFAEFB259924204AEC6
					datasource-creation-time = 2020-11-05T10:31:23.188908-05:00
					output-file-modified-time = 2000-07-21T14:20:05-04:00
					idx-feature-type = UBYTE
					features-resource-hash = 8D422C7B0A1C1C79245A5BCF07FE86E33EEAFEE792B84584AEC276F5A2DBC4E6
					host-short-name = DataSource
				)
			transformations = List[]
			is-sequence = false
			is-dense = false
			num-examples = 10000
			num-features = 668
			num-outputs = 10
			tribuo-version = 4.0.2
		)
	tribuo-version = 4.0.2
)
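The resource hashes in the provenance above are 64 hexadecimal characters long, which is the length of a SHA-256 digest. Assuming that algorithm, a comparable fingerprint for a byte array can be sketched with the standard library as follows. SketchResourceHash is a hypothetical helper, not Tribuo's hashing code.

```java
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

class SketchResourceHash {
    // Returns the SHA-256 digest of the bytes as an upper-case hex string.
    static String sha256Hex(byte[] data) {
        try {
            MessageDigest digest = MessageDigest.getInstance("SHA-256");
            StringBuilder hex = new StringBuilder();
            for (byte b : digest.digest(data)) {
                hex.append(String.format("%02X", b)); // two hex digits per byte
            }
            return hex.toString();
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to provide SHA-256.
            throw new IllegalStateException(e);
        }
    }
}
```

To fingerprint a file, the bytes could be read with Files.readAllBytes(path) and passed to sha256Hex. Any change to the file contents changes the digest, which is what makes the hash useful for checking that a model is being rebuilt from the same data.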

Feature Transformations

We can take the new trainer, wrap it programmatically in a TransformTrainer which rescales the input features into the range [0,1], and still generate provenance and configuration automatically as the model is trained.

In [22]:
var transformations = new TransformationMap(List.of(new LinearScalingTransformation(0,1)));
var transformed = new TransformTrainer(newTrainer,transformations);
var transformStart = System.currentTimeMillis();
var transformedModel = transformed.train(newDataset);
var transformEnd = System.currentTimeMillis();
System.out.println("Training transformed logistic regression took " + Util.formatDuration(transformStart,transformEnd));
Training transformed logistic regression took (00:00:08:613)
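The rescaling itself is ordinary min-max scaling. The sketch below shows the textbook formula for a single feature using only the standard library; SketchLinearScaling is a hypothetical helper, and Tribuo's LinearScalingTransformation may differ in details such as how it treats the implicit zeros in sparse features.

```java
class SketchLinearScaling {
    // Rescales values so the observed minimum maps to targetMin and the maximum to targetMax.
    static double[] rescale(double[] values, double targetMin, double targetMax) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        double range = max - min;
        double[] scaled = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            // A constant feature has no range; this sketch maps it to targetMin.
            double unit = range == 0.0 ? 0.0 : (values[i] - min) / range;
            scaled[i] = targetMin + unit * (targetMax - targetMin);
        }
        return scaled;
    }
}
```

For a pixel feature stored as an unsigned byte that takes values between 0 and 255, rescaling to [0,1] maps 0 to 0.0, 51 to 0.2 and 255 to 1.0.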

Now we'll evaluate the rescaled model. Here we see that rescaling the data into the zero-one range improves the linear model's accuracy by a couple of percentage points (from 0.891 to 0.915), as all the data is now on the same scale. As expected it's still not SOTA, but we're not using a huge CNN or some other complex model. For that you can try out our TensorFlow interface, or use the XGBoost trainer we loaded in from the original configuration file.

In [23]:
LabelEvaluation transformedEvaluator = evaluator.evaluate(transformedModel,testData);
System.out.println(transformedEvaluator.toString());
System.out.println(transformedEvaluator.getConfusionMatrix().toString());
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         957          23          40       0.977       0.960       0.968
1                           1,135       1,109          26          36       0.977       0.969       0.973
2                           1,032         940          92          90       0.911       0.913       0.912
3                           1,010         927          83         141       0.918       0.868       0.892
4                             982         914          68          73       0.931       0.926       0.928
5                             892         813          79         183       0.911       0.816       0.861
6                             958         892          66          45       0.931       0.952       0.941
7                           1,028         918         110          54       0.893       0.944       0.918
8                             974         753         221          60       0.773       0.926       0.843
9                           1,009         926          83         129       0.918       0.878       0.897
Total                      10,000       9,149         851         851
Accuracy                                                                    0.915
Micro Average                                                               0.915       0.915       0.915
Macro Average                                                               0.914       0.915       0.913
Balanced Error Rate                                                         0.086
               0       1       2       3       4       5       6       7       8       9
0            957       0       1       2       1      12       4       2       1       0
1              0   1,109      10       3       0       2       3       2       6       0
2              4       9     940      18       9       7      11      11      19       4
3              6       0      25     927       0      26       2       7       9       8
4              1       1       7       4     914       0       9       7       4      35
5              7       1       2      30       8     813       9       3      18       1
6              8       2      14       3       8      27     892       2       2       0
7              1       7      17      19       8       1       0     918       1      56
8              7       9      13      46      11      93       7      10     753      25
9              6       7       1      16      28      15       0      10       0     926

We can emit a configuration which includes both the transformation trainer and the original trainer pulled from the old configuration. We'll write it out to a byte array rather than putting it on disk, but the process is the same.

In [24]:
var transformedProvConfig = ProvenanceUtil.extractConfiguration(transformedModel.getProvenance());
var baos = new ByteArrayOutputStream();
newCM = new ConfigurationManager();
newCM.addConfiguration(transformedProvConfig);
newCM.save(baos,"json",true);
baos.toString();
Out[24]:
{
  "config" : {
    "components" : [ {
      "name" : "linearscalingtransformation-4",
      "type" : "org.tribuo.transform.transformations.LinearScalingTransformation",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "targetMax" : "1.0",
        "targetMin" : "0.0"
      }
    }, {
      "name" : "labelfactory-7",
      "type" : "org.tribuo.classification.LabelFactory",
      "export" : "false",
      "import" : "false"
    }, {
      "name" : "adagrad-5",
      "type" : "org.tribuo.math.optimisers.AdaGrad",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "epsilon" : "0.01",
        "initialLearningRate" : "0.5",
        "initialValue" : "0.0"
      }
    }, {
      "name" : "linearsgdtrainer-2",
      "type" : "org.tribuo.classification.sgd.linear.LinearSGDTrainer",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "seed" : "1",
        "minibatchSize" : "1",
        "shuffle" : "true",
        "epochs" : "2",
        "optimiser" : "adagrad-5",
        "objective" : "logmulticlass-6",
        "loggingInterval" : "10000"
      }
    }, {
      "name" : "transformtrainer-0",
      "type" : "org.tribuo.transform.TransformTrainer",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "transformations" : "transformationmap-1",
        "densify" : "false",
        "innerTrainer" : "linearsgdtrainer-2"
      }
    }, {
      "name" : "logmulticlass-6",
      "type" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
      "export" : "false",
      "import" : "false"
    }, {
      "name" : "idxdatasource-3",
      "type" : "org.tribuo.datasource.IDXDataSource",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "outputPath" : "/Users/apocock/Development/Tribuo/tutorials/train-labels-idx1-ubyte.gz",
        "outputFactory" : "labelfactory-7",
        "featuresPath" : "/Users/apocock/Development/Tribuo/tutorials/train-images-idx3-ubyte.gz"
      }
    }, {
      "name" : "transformationmap-1",
      "type" : "org.tribuo.transform.TransformationMap",
      "export" : "false",
      "import" : "false",
      "properties" : {
        "featureTransformationList" : { },
        "globalTransformations" : [ {
          "item" : "linearscalingtransformation-4"
        } ]
      }
    } ]
  }
}

Aside from the names (which have different tag numbers) we can see that this configuration is identical to the previous one, but with the addition of transformtrainer-0 and its dependents.

Conclusion

We've taken a closer look at Tribuo's configuration and provenance systems, showing how to train a model using a configuration file, how to inspect the model's provenance, extract its configuration, and finally how to combine that extracted configuration with other programmatic elements of the Tribuo library (in this case the feature transformation system). We saw that the provenance combines both the configuration of the trainer and the datasource, along with runtime information extracted from the dataset itself (e.g., timestamps and file hashes).

Tribuo's configuration system is integrated into a CLI options/arguments parsing system, which can be used to override elements from the configuration file. The values from the options are then stored in the ConfigurationManager and appear in the provenance and downstream configuration objects as expected. Tribuo also provides a redaction system for configuration files (e.g., to ensure a password isn't stored in the provenance) and for provenance objects themselves (e.g., to remove the data provenance from a trained model), which aids model deployment to untrusted or less trusted systems.