Class ClassifierChainTrainer

java.lang.Object
org.tribuo.multilabel.baseline.ClassifierChainTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<MultiLabel>

public final class ClassifierChainTrainer extends Object implements Trainer<MultiLabel>
A trainer for a Classifier Chain.

Classifier chains convert binary classifiers into multi-label classifiers by training one classifier per label (similar to the Binary Relevance approach), but in a specific order (the chain). Classifiers further down the chain use the labels from all previously computed classifiers as features, thus allowing the model to incorporate some measure of label dependence.
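A minimal plain-Java sketch of the prediction side of a chain follows. It assumes each binary classifier can be modelled as a `Predicate` over a feature map. The class name, the `chain:` feature prefix and the toy classifiers are hypothetical and are not Tribuo's API. The sketch only shows that each member's output is appended to the features before the next member runs.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

/** Hypothetical classifier-chain prediction loop; not Tribuo's implementation. */
public class ChainPredictionSketch {

    /** Runs each binary classifier in chain order, feeding its output forward as a feature. */
    public static List<String> predict(Map<String, Double> features,
                                       List<String> labelOrder,
                                       Map<String, Predicate<Map<String, Double>>> classifiers) {
        Map<String, Double> augmented = new HashMap<>(features);
        List<String> predicted = new ArrayList<>();
        for (String label : labelOrder) {
            boolean positive = classifiers.get(label).test(augmented);
            if (positive) {
                predicted.add(label);
            }
            // Hypothetical feature name; every later chain member can read this value.
            augmented.put("chain:" + label, positive ? 1.0 : 0.0);
        }
        return predicted;
    }

    public static void main(String[] args) {
        Map<String, Predicate<Map<String, Double>>> classifiers = new HashMap<>();
        classifiers.put("sports", f -> f.getOrDefault("ball", 0.0) > 0.5);
        // "outdoors" only fires when the earlier "sports" member predicted positive.
        classifiers.put("outdoors", f -> f.getOrDefault("chain:sports", 0.0) > 0.5);
        Map<String, Double> x = new HashMap<>();
        x.put("ball", 1.0);
        System.out.println(predict(x, List.of("sports", "outdoors"), classifiers)); // [sports, outdoors]
    }
}
```

With the order reversed, `outdoors` runs before `sports` has been predicted and returns false. This illustrates why the label ordering matters.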

Choosing the optimal label ordering is tricky as the label dependence is usually unknown, so one popular alternative is to produce an ensemble of randomly ordered chains, which mitigates a poor label ordering by averaging across many orderings.
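The following sketches the ensemble idea. It assumes each chain in the ensemble has already produced a set of predicted labels. The orderings are drawn from a seeded `java.util.Random`, and the final prediction keeps any label chosen by at least a threshold fraction of chains. The class and method names are illustrative, as is the choice of a simple vote fraction rather than averaged scores. None of this describes a specific library.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

/** Hypothetical helpers for an ensemble of randomly ordered classifier chains. */
public class ChainEnsembleSketch {

    /** Draws k random label orderings from a single seeded RNG. */
    public static List<List<String>> randomOrders(List<String> labels, int k, long seed) {
        Random rng = new Random(seed);
        List<List<String>> orders = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            List<String> order = new ArrayList<>(labels);
            Collections.shuffle(order, rng);
            orders.add(order);
        }
        return orders;
    }

    /** Keeps each label predicted by at least the given fraction of the chains. */
    public static Set<String> vote(List<Set<String>> chainPredictions, double threshold) {
        Map<String, Integer> counts = new TreeMap<>();
        for (Set<String> prediction : chainPredictions) {
            for (String label : prediction) {
                counts.merge(label, 1, Integer::sum);
            }
        }
        Set<String> result = new TreeSet<>();
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            if ((double) e.getValue() / chainPredictions.size() >= threshold) {
                result.add(e.getKey());
            }
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(randomOrders(List.of("a", "b", "c"), 3, 42L));
        List<Set<String>> predictions = List.of(Set.of("a", "b"), Set.of("a"), Set.of("a", "c"));
        System.out.println(vote(predictions, 0.5)); // [a]
    }
}
```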

See:

 Read, J., Pfahringer, B., Holmes, G., & Frank, E.
 "Classifier Chains for Multi-Label Classification"
 Machine Learning, pages 333-359, 2011.
 
  • Field Details

    • CC_PREFIX

      public static final String CC_PREFIX
      The prefix for classifier chain added features.
    • CC_POSITIVE

      public static final String CC_POSITIVE
      The string used in the feature name for positive labels.
    • CC_NEGATIVE

      public static final String CC_NEGATIVE
      The string used in the feature name for negative labels.
    • CC_SEPARATOR

      public static final String CC_SEPARATOR
      The joiner character for classifier chain added features.
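      The four constants above combine to form the names of the features that carry the labels of earlier chain members. This page does not show their actual string values. The sketch below therefore uses made-up values and an assumed composition order: prefix, separator, label, separator, polarity. Check the constants on ClassifierChainTrainer itself before relying on the exact format.

```java
/** Hypothetical reconstruction of a chain feature name; the constant values are invented. */
public class ChainFeatureNameSketch {
    // Made-up stand-ins for ClassifierChainTrainer's CC_* constants.
    static final String CC_PREFIX = "CC";
    static final String CC_SEPARATOR = "_";
    static final String CC_POSITIVE = "pos";
    static final String CC_NEGATIVE = "neg";

    /** Builds the feature name recording whether an earlier label was positive or negative. */
    public static String featureName(String label, boolean positive) {
        return CC_PREFIX + CC_SEPARATOR + label + CC_SEPARATOR
                + (positive ? CC_POSITIVE : CC_NEGATIVE);
    }

    public static void main(String[] args) {
        System.out.println(featureName("sports", true));  // CC_sports_pos
        System.out.println(featureName("sports", false)); // CC_sports_neg
    }
}
```

      Encoding the polarity in the name would let the inner Trainer<Label> treat each outcome as its own indicator feature. Whether Tribuo uses exactly this layout is an assumption.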
  • Constructor Details

    • ClassifierChainTrainer

      public ClassifierChainTrainer(Trainer<Label> innerTrainer, long seed)
      Builds a classifier chain trainer using the specified member trainer and seed.

      The chain is built from n different classifiers, one per label. At training time, each classifier in the chain sees the ground-truth values of the labels earlier in the chain as features. At test time, it sees the predictions made by the earlier chain members instead.

      This trainer will generate a different random label ordering for each call to train(Dataset).

      Parameters:
      innerTrainer - The trainer to use for each binary classifier.
      seed - The RNG seed for the chain order.
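      The following minimal sketch shows how a single seed can still yield a different label ordering on every train call. One java.util.Random is created from the seed and shared across calls, so each call continues the same random stream. The class and its nextOrder method are hypothetical stand-ins for the trainer's internal RNG handling, which may use a different generator.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical model of a trainer that shuffles the label order once per train call. */
public class SeededOrderSketch {
    private final Random rng;

    public SeededOrderSketch(long seed) {
        this.rng = new Random(seed);
    }

    /** Stand-in for one train call: draws the next ordering from the shared RNG. */
    public List<String> nextOrder(List<String> labels) {
        List<String> order = new ArrayList<>(labels);
        Collections.shuffle(order, rng);
        return order;
    }

    public static void main(String[] args) {
        List<String> labels = List.of("a", "b", "c", "d");
        SeededOrderSketch first = new SeededOrderSketch(12345L);
        SeededOrderSketch second = new SeededOrderSketch(12345L);
        // Successive calls on one instance usually differ; the same seed replays the same sequence.
        for (int i = 0; i < 3; i++) {
            System.out.println(first.nextOrder(labels) + " " + second.nextOrder(labels));
        }
    }
}
```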
    • ClassifierChainTrainer

      public ClassifierChainTrainer(Trainer<Label> innerTrainer, List<String> labelOrder)
      Builds a classifier chain trainer using the specified member trainer and label ordering.

      The chain is built from n different classifiers, one per label. At training time, each classifier in the chain sees the ground-truth values of the labels earlier in the chain as features. At test time, it sees the predictions made by the earlier chain members instead.

      This trainer uses the supplied label ordering. It throws an IllegalArgumentException if the ordering does not cover all the labels in the training set.

      Parameters:
      innerTrainer - The trainer to use for each binary classifier.
      labelOrder - The label ordering.
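      The coverage check described above can be sketched as a set difference between the labels present in the training data and the supplied ordering. The class name, method name and error message are hypothetical. The sketch checks only coverage, which is the condition stated here, and ignores other possible checks such as duplicate entries.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical coverage check for a user-supplied chain ordering. */
public class LabelOrderCheckSketch {

    /** Throws if labelOrder omits any label that appears in the training data. */
    public static void validate(List<String> labelOrder, Set<String> trainingLabels) {
        Set<String> missing = new HashSet<>(trainingLabels);
        missing.removeAll(labelOrder);
        if (!missing.isEmpty()) {
            throw new IllegalArgumentException("Label ordering does not cover labels: " + missing);
        }
    }

    public static void main(String[] args) {
        validate(List.of("a", "b", "c"), Set.of("a", "b", "c")); // passes
        try {
            validate(List.of("a", "b"), Set.of("a", "b", "c"));
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // Label ordering does not cover labels: [c]
        }
    }
}
```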
  • Method Details

    • postConfig

      public void postConfig()
      Used by the OLCUT configuration system and should not be called by external code.
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
    • train

      public ClassifierChainModel train(Dataset<MultiLabel> examples)
      Description copied from interface: Trainer
      Trains a predictive model using the examples in the given data set.
      Specified by:
      train in interface Trainer<MultiLabel>
      Parameters:
      examples - the data set containing the examples.
      Returns:
      a predictive model that can be used to generate predictions for new examples.
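      According to the constructor documentation, each chain member sees the ground-truth values of the earlier labels during training. The sketch below builds the feature maps for the member at a given chain position. It assumes examples are plain Map<String, Double> feature maps, each paired with a Set<String> of true labels. The class, the method and the "chain:" prefix are hypothetical and do not mirror Tribuo's Dataset API.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical construction of per-member training features for a classifier chain. */
public class ChainTrainingDataSketch {

    /**
     * Returns the feature maps seen by the chain member at the given position:
     * the original features plus a ground-truth indicator for every earlier label.
     */
    public static List<Map<String, Double>> trainingFeatures(List<Map<String, Double>> features,
                                                             List<Set<String>> groundTruth,
                                                             List<String> labelOrder,
                                                             int position) {
        List<Map<String, Double>> out = new ArrayList<>();
        for (int i = 0; i < features.size(); i++) {
            Map<String, Double> augmented = new HashMap<>(features.get(i));
            for (String earlier : labelOrder.subList(0, position)) {
                // Training uses the true label, not a prediction from the earlier member.
                augmented.put("chain:" + earlier, groundTruth.get(i).contains(earlier) ? 1.0 : 0.0);
            }
            out.add(augmented);
        }
        return out;
    }

    public static void main(String[] args) {
        List<Map<String, Double>> features = List.of(Map.of("x", 1.0), Map.of("x", 0.0));
        List<Set<String>> truth = List.of(Set.of("a"), Set.of("b"));
        // The member at position 1 ("b") sees the true value of "a" as an extra feature.
        System.out.println(trainingFeatures(features, truth, List.of("a", "b"), 1));
    }
}
```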
    • train

      public ClassifierChainModel train(Dataset<MultiLabel> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
      Description copied from interface: Trainer
      Trains a predictive model using the examples in the given data set.
      Specified by:
      train in interface Trainer<MultiLabel>
      Parameters:
      examples - the data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      Returns:
      a predictive model that can be used to generate predictions for new examples.
    • train

      public ClassifierChainModel train(Dataset<MultiLabel> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
      Description copied from interface: Trainer
      Trains a predictive model using the examples in the given data set.
      Specified by:
      train in interface Trainer<MultiLabel>
      Parameters:
      examples - the data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      invocationCount - The invocation counter that the trainer should be set to before training, which in most cases alters the state of the RNG inside this trainer. If the value is set to Trainer.INCREMENT_INVOCATION_COUNT then the invocation count is not changed.
      Returns:
      a predictive model that can be used to generate predictions for new examples.
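      The invocationCount argument can be read as follows: reposition the trainer's counter, and therefore its RNG, before training, unless the sentinel is passed. The sketch below models only the counter. Both the class and the sentinel value -1 are assumptions for illustration, so real code should use the Trainer.INCREMENT_INVOCATION_COUNT constant rather than a literal.

```java
/** Hypothetical model of the invocationCount argument to train(...). */
public class InvocationArgumentSketch {
    // Assumed stand-in for Trainer.INCREMENT_INVOCATION_COUNT.
    public static final int INCREMENT_INVOCATION_COUNT = -1;

    private int invocationCount = 0;

    /** Returns the invocation count this training run uses, then advances the counter. */
    public int train(int requestedCount) {
        if (requestedCount != INCREMENT_INVOCATION_COUNT) {
            // A real trainer would also replay its RNG to this point.
            invocationCount = requestedCount;
        }
        int used = invocationCount;
        invocationCount++;
        return used;
    }

    public int getInvocationCount() {
        return invocationCount;
    }

    public static void main(String[] args) {
        InvocationArgumentSketch t = new InvocationArgumentSketch();
        System.out.println(t.train(INCREMENT_INVOCATION_COUNT)); // 0
        System.out.println(t.train(INCREMENT_INVOCATION_COUNT)); // 1
        System.out.println(t.train(10));                         // 10
        System.out.println(t.getInvocationCount());              // 11
    }
}
```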
    • getInvocationCount

      public int getInvocationCount()
      Description copied from interface: Trainer
      The number of times this trainer instance has had its train method invoked.

      This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.

      Specified by:
      getInvocationCount in interface Trainer<MultiLabel>
      Returns:
      The number of train invocations.
    • setInvocationCount

      public void setInvocationCount(int invocationCount)
      Description copied from interface: Trainer
      Set the internal state of the trainer to the provided number of invocations of the train method.

      This is used when reproducing a Tribuo-trained model. It restores the RNG to the state it was in when Tribuo trained the original model, by simulating invocations of the train method. This method should ALWAYS be overridden; the default implementation exists purely for compatibility.

      In a future major release this default implementation will be removed.

      Specified by:
      setInvocationCount in interface Trainer<MultiLabel>
      Parameters:
      invocationCount - the number of invocations of the train method to simulate
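      The replay described for setInvocationCount can be sketched in two steps: reset the RNG to its seed, then consume it exactly as n calls to train would have. This reproduces the stream only if each call draws from the RNG in the same way. The sketch guarantees that by shuffling a fixed label list. The class and its methods are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical trainer whose RNG state can be restored by simulating earlier train calls. */
public class InvocationReplaySketch {
    private final long seed;
    private final List<String> labels;
    private Random rng;
    private int invocationCount;

    public InvocationReplaySketch(long seed, List<String> labels) {
        this.seed = seed;
        this.labels = new ArrayList<>(labels);
        this.rng = new Random(seed);
        this.invocationCount = 0;
    }

    /** Stand-in for train(...): draws this call's label ordering and bumps the counter. */
    public List<String> train() {
        List<String> order = new ArrayList<>(labels);
        Collections.shuffle(order, rng);
        invocationCount++;
        return order;
    }

    /** Resets the RNG, then simulates n train calls so the next call matches call n + 1. */
    public void setInvocationCount(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("invocationCount must be non-negative, found " + n);
        }
        rng = new Random(seed);
        invocationCount = 0;
        for (int i = 0; i < n; i++) {
            train();
        }
    }

    public int getInvocationCount() {
        return invocationCount;
    }

    public static void main(String[] args) {
        List<String> labels = List.of("a", "b", "c", "d", "e");
        InvocationReplaySketch original = new InvocationReplaySketch(99L, labels);
        original.train();
        original.train();
        List<String> third = original.train();
        InvocationReplaySketch replay = new InvocationReplaySketch(99L, labels);
        replay.setInvocationCount(2);
        System.out.println(replay.train().equals(third)); // true
    }
}
```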
    • getProvenance

      public TrainerProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>