Class CRFTrainer

java.lang.Object
org.tribuo.classification.sgd.crf.CRFTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SequenceTrainer<Label>, WeightedExamples

public class CRFTrainer extends Object implements SequenceTrainer<Label>, WeightedExamples
A trainer for conditional random fields (CRFs) using stochastic gradient descent (SGD). Modelled after FACTORIE's trainer for CRFs.

See:

 Lafferty J, McCallum A, Pereira FC.
 "Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data"
 Proceedings of the 18th International Conference on Machine Learning 2001 (ICML 2001).
 
  • Constructor Details

    • CRFTrainer

      public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
      Creates a CRFTrainer which uses SGD to learn the parameters.
      Parameters:
      optimiser - The gradient optimiser to use.
      epochs - The number of SGD epochs (complete passes through the training data).
      loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
      minibatchSize - The size of the minibatches used to aggregate gradients.
      seed - A seed for the random number generator, used to shuffle the examples before each epoch.
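
      The epochs, minibatchSize and seed parameters interact in the usual minibatch SGD pattern. The sketch below is a stand-alone illustration, not Tribuo's implementation: the class MinibatchSketch and its schedule method are hypothetical, and java.util.Random is assumed as a stand-in for whatever random number generator the trainer uses internally. It shows the examples being reshuffled before each epoch and then cut into minibatches, with a smaller final batch when the dataset size is not a multiple of minibatchSize.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical stand-in, not part of Tribuo: shows how epochs, minibatchSize
// and seed could interact in a minibatch SGD loop.
class MinibatchSketch {

    // Returns, for each epoch, the list of minibatches (each a list of example indices).
    static List<List<List<Integer>>> schedule(int numExamples, int epochs, int minibatchSize, long seed) {
        if (minibatchSize < 1) {
            throw new IllegalArgumentException("minibatchSize must be at least 1");
        }
        // Assumption: java.util.Random stands in for the trainer's own RNG.
        Random rng = new Random(seed);
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        List<List<List<Integer>>> allEpochs = new ArrayList<>();
        for (int e = 0; e < epochs; e++) {
            // Reshuffle the examples before each epoch.
            Collections.shuffle(order, rng);
            List<List<Integer>> batches = new ArrayList<>();
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                // The gradients of these examples would be aggregated into one optimiser step.
                batches.add(new ArrayList<>(order.subList(start, end)));
            }
            allEpochs.add(batches);
        }
        return allEpochs;
    }

    public static void main(String[] args) {
        List<List<List<Integer>>> plan = schedule(5, 2, 2, 42L);
        for (int e = 0; e < plan.size(); e++) {
            System.out.println("epoch " + e + ": " + plan.get(e));
        }
    }
}
```

      In this sketch, calling schedule twice with the same seed yields the same sequence of minibatches, which is the reproducibility the seed parameter is meant to provide.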
    • CRFTrainer

      public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
      Creates a CRFTrainer which uses SGD to learn the parameters, with the minibatch size set to 1.
      Parameters:
      optimiser - The gradient optimiser to use.
      epochs - The number of SGD epochs (complete passes through the training data).
      loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
      seed - A seed for the random number generator, used to shuffle the examples before each epoch.
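
      One way to read the loggingInterval parameter is sketched below. This is an assumption about the behaviour rather than Tribuo's code: the class LoggingIntervalSketch and both of its methods are hypothetical, the rule "log when the iteration count is a multiple of the interval" is assumed, and the sketch treats every non-positive interval as disabled (which includes the documented -1).

```java
// Hypothetical helper, not Tribuo's code: one possible interpretation of loggingInterval.
class LoggingIntervalSketch {

    // Assumption: the loss is logged whenever the iteration count is a multiple
    // of loggingInterval. Non-positive intervals, including the documented -1,
    // disable logging in this sketch.
    static boolean shouldLog(int iteration, int loggingInterval) {
        return loggingInterval > 0 && iteration % loggingInterval == 0;
    }

    // Counts how many log lines a run would emit, with iterations numbered from 1.
    static int countLogLines(int totalIterations, int loggingInterval) {
        int count = 0;
        for (int i = 1; i <= totalIterations; i++) {
            if (shouldLog(i, loggingInterval)) {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println("interval 100 over 250 iterations: " + countLogLines(250, 100));
        System.out.println("interval -1 over 250 iterations: " + countLogLines(250, -1));
    }
}
```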
    • CRFTrainer

      public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, long seed)
      Creates a CRFTrainer which uses SGD to learn the parameters, with the minibatch size set to 1 and the logging interval set to 100.
      Parameters:
      optimiser - The gradient optimiser to use.
      epochs - The number of SGD epochs (complete passes through the training data).
      seed - A seed for the random number generator, used to shuffle the examples before each epoch.
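
      The three constructors differ only in which values they default. The sketch below mirrors those documented defaults with plain constructor chaining. TrainerSettingsSketch is a hypothetical settings holder, not a Tribuo class, and it omits the optimiser argument so that it stays self-contained.

```java
// Hypothetical settings holder, not a Tribuo class: mirrors the documented
// defaults of the three CRFTrainer constructors via constructor chaining.
class TrainerSettingsSketch {
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;

    // Full form: every value is supplied by the caller.
    TrainerSettingsSketch(int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
    }

    // Minibatch size defaults to 1.
    TrainerSettingsSketch(int epochs, int loggingInterval, long seed) {
        this(epochs, loggingInterval, 1, seed);
    }

    // Minibatch size defaults to 1 and the logging interval to 100.
    TrainerSettingsSketch(int epochs, long seed) {
        this(epochs, 100, 1, seed);
    }

    public static void main(String[] args) {
        TrainerSettingsSketch s = new TrainerSettingsSketch(5, 1L);
        System.out.println("minibatchSize=" + s.minibatchSize + ", loggingInterval=" + s.loggingInterval);
    }
}
```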
  • Method Details

    • postConfig

      public void postConfig()
      Used by the OLCUT configuration system, and should not be called by external code.
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
    • setShuffle

      public void setShuffle(boolean shuffle)
      Turns shuffling of the examples on or off.

      Shuffling is on by default, so it is not exposed in the constructors. This method is intended for debugging.

      Parameters:
      shuffle - If true, shuffle the examples; if false, leave them in their current order.
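
      The effect of the flag can be illustrated with a small stand-alone sketch. ShuffleToggleSketch is hypothetical and is not how CRFTrainer is implemented; it assumes java.util.Random and Collections.shuffle as the shuffling mechanism. With shuffling off the examples are visited in dataset order, which makes a debugging run easier to follow step by step.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical stand-in, not Tribuo's code: a shuffle flag that defaults to on.
class ShuffleToggleSketch {
    private boolean shuffle = true; // on by default, as documented for CRFTrainer
    private final long seed;

    ShuffleToggleSketch(long seed) {
        this.seed = seed;
    }

    void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    // Returns the order in which examples 0..n-1 would be visited in one epoch.
    List<Integer> epochOrder(int n) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            order.add(i);
        }
        if (shuffle) {
            // Assumption: a seeded java.util.Random drives the permutation.
            Collections.shuffle(order, new Random(seed));
        }
        return order;
    }

    public static void main(String[] args) {
        ShuffleToggleSketch sketch = new ShuffleToggleSketch(7L);
        System.out.println("shuffled:   " + sketch.epochOrder(6));
        sketch.setShuffle(false);
        System.out.println("unshuffled: " + sketch.epochOrder(6));
    }
}
```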
    • train

      public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
      Description copied from interface: SequenceTrainer
      Trains a sequence prediction model using the examples in the given data set.
      Specified by:
      train in interface SequenceTrainer<Label>
      Parameters:
      sequenceExamples - the data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      Returns:
      a predictive model that can be used to generate predictions for new examples.
    • getInvocationCount

      public int getInvocationCount()
      Description copied from interface: SequenceTrainer
      Returns the number of times the train method has been invoked.
      Specified by:
      getInvocationCount in interface SequenceTrainer<Label>
      Returns:
      The number of times train has been invoked.
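
      A plausible reason to track the invocation count is reproducibility: if a trainer's random number generator advances between calls, the second call to train does not see the same random stream as the first. The sketch below illustrates that idea under stated assumptions. InvocationCounterSketch, its train stand-in and its replay helper are hypothetical, and java.util.Random is assumed in place of the trainer's actual generator.

```java
import java.util.Random;

// Hypothetical stand-in, not Tribuo's code: a counter that records how many
// times train has been called on a trainer with a stateful RNG.
class InvocationCounterSketch {
    private final Random rng;
    private int invocationCount = 0;

    InvocationCounterSketch(long seed) {
        this.rng = new Random(seed);
    }

    // Stand-in for train: does no learning, only returns the random value
    // that this call would consume (for example, to drive shuffling).
    synchronized long train() {
        invocationCount++;
        return rng.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    // Replays what the k-th call (1-based) on a fresh trainer with this seed would draw.
    static long replay(long seed, int k) {
        InvocationCounterSketch fresh = new InvocationCounterSketch(seed);
        long value = 0L;
        for (int i = 0; i < k; i++) {
            value = fresh.train();
        }
        return value;
    }

    public static void main(String[] args) {
        InvocationCounterSketch trainer = new InvocationCounterSketch(3L);
        trainer.train();
        long second = trainer.train();
        // Knowing the seed and the invocation count is enough to reproduce the draw.
        System.out.println(second == replay(3L, trainer.getInvocationCount()));
    }
}
```

      In this sketch the seed alone is not enough to reproduce the second call; the seed together with the invocation count is.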
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getProvenance

      public TrainerProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>