Class CRFTrainer
java.lang.Object
org.tribuo.classification.sgd.crf.CRFTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SequenceTrainer<Label>, WeightedExamples
A trainer for conditional random fields (CRFs) using stochastic gradient descent (SGD). Modelled after FACTORIE's trainer for CRFs.
See:
Lafferty J, McCallum A, Pereira FC. "Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data." Proceedings of the 18th International Conference on Machine Learning (ICML 2001), 2001.
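This page does not spell out the training objective. The standard criterion from Lafferty et al. is the conditional log-likelihood of the gold label sequence, so an SGD trainer minimises its negative. As a rough sketch of that quantity for a linear-chain CRF, assuming per-position label scores (emissions) and label-to-label scores (transitions) have already been computed from the features, the forward algorithm gives log Z(x) in O(T*K^2) time instead of enumerating all K^T label sequences. The class and method names are hypothetical, and Tribuo's actual parameterisation may differ.

```java
/**
 * Hypothetical linear-chain CRF scoring sketch (not Tribuo's implementation).
 * emissions[t][k]: score of label k at position t.
 * transitions[j][k]: score of moving from label j to label k.
 */
final class LinearChainCrf {
    private LinearChainCrf() {}

    /** Numerically stable log(sum(exp(x))). */
    static double logSumExp(double[] x) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : x) {
            max = Math.max(max, v);
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return max;
        }
        double sum = 0.0;
        for (double v : x) {
            sum += Math.exp(v - max);
        }
        return max + Math.log(sum);
    }

    /** Unnormalised score of one label sequence. */
    static double pathScore(double[][] emissions, double[][] transitions, int[] labels) {
        double score = emissions[0][labels[0]];
        for (int t = 1; t < labels.length; t++) {
            score += transitions[labels[t - 1]][labels[t]] + emissions[t][labels[t]];
        }
        return score;
    }

    /** log Z(x) via the forward algorithm: alpha[k] is the log-sum over paths ending in label k. */
    static double logPartition(double[][] emissions, double[][] transitions) {
        int numLabels = emissions[0].length;
        double[] alpha = emissions[0].clone();
        double[] terms = new double[numLabels];
        for (int t = 1; t < emissions.length; t++) {
            double[] next = new double[numLabels];
            for (int k = 0; k < numLabels; k++) {
                for (int j = 0; j < numLabels; j++) {
                    terms[j] = alpha[j] + transitions[j][k];
                }
                next[k] = logSumExp(terms) + emissions[t][k];
            }
            alpha = next;
        }
        return logSumExp(alpha);
    }

    /** Negative log-likelihood: -log p(labels | x) = log Z(x) - score(labels). */
    static double negativeLogLikelihood(double[][] emissions, double[][] transitions, int[] labels) {
        return logPartition(emissions, transitions) - pathScore(emissions, transitions, labels);
    }
}
```

On small inputs, brute-force enumeration of every label sequence gives the same log Z, which is a convenient correctness check; the forward recursion is what keeps training on long sequences tractable.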
Constructor Summary
Constructor | Description
CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) | Creates a CRFTrainer which uses SGD to learn the parameters.
CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed) | Sets the minibatch size to 1.
CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, long seed) | Sets the minibatch size to 1 and the logging interval to 100.
Method Summary
Modifier and Type | Method | Description
int | getInvocationCount() | Returns the number of times the train method has been invoked.
TrainerProvenance | getProvenance() |
void | postConfig() | Used by the OLCUT configuration system, and should not be called by external code.
void | setShuffle(boolean shuffle) | Turns shuffling of examples on or off.
String | toString() |
CRFModel | train(SequenceDataset<Label> sequenceExamples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) | Trains a sequence prediction model using the examples in the given data set.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface org.tribuo.sequence.SequenceTrainer
train
Constructor Details
CRFTrainer
public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
Creates a CRFTrainer which uses SGD to learn the parameters.
Parameters:
optimiser - The gradient optimiser to use.
epochs - The number of SGD epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
minibatchSize - The size of the minibatches used to aggregate gradients.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
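To make the roles of minibatchSize and loggingInterval concrete, here is a minimal sketch of a minibatch SGD loop on a toy one-dimensional least-squares problem rather than a CRF. It assumes that one iteration means one parameter update (one minibatch), that gradients within a minibatch are averaged, and that the logged value is the average per-example loss since the previous log line; none of this is taken from Tribuo's source, and the class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical minibatch SGD loop on a toy loss 0.5 * (w * x - y)^2 (not Tribuo code). */
final class MinibatchSgdSketch {
    final List<String> log = new ArrayList<>();
    double weight = 0.0;

    void train(double[] xs, double[] ys, int epochs, int minibatchSize,
               int loggingInterval, double learningRate) {
        int iteration = 0;
        double lossSinceLog = 0.0;
        int examplesSinceLog = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int start = 0; start < xs.length; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, xs.length);
                double gradientSum = 0.0;
                for (int i = start; i < end; i++) {
                    double residual = weight * xs[i] - ys[i];
                    lossSinceLog += 0.5 * residual * residual;
                    gradientSum += residual * xs[i]; // derivative of the loss w.r.t. weight
                }
                examplesSinceLog += end - start;
                // One update per minibatch, using the averaged gradient.
                weight -= learningRate * gradientSum / (end - start);
                iteration++;
                // Skip logging for non-positive intervals: n % -1 is always 0 in Java, and n % 0 throws.
                if (loggingInterval > 0 && iteration % loggingInterval == 0) {
                    log.add("iteration " + iteration + ", average loss " + lossSinceLog / examplesSinceLog);
                    lossSinceLog = 0.0;
                    examplesSinceLog = 0;
                }
            }
        }
    }
}
```

With minibatchSize set to 1, which the shorter constructors use, every example produces its own update, so an iteration and an example coincide.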
CRFTrainer
public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
Creates a CRFTrainer which uses SGD to learn the parameters, with the minibatch size set to 1.
Parameters:
optimiser - The gradient optimiser to use.
epochs - The number of SGD epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
CRFTrainer
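public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, long seed)
Creates a CRFTrainer which uses SGD to learn the parameters, with the minibatch size set to 1 and the logging interval set to 100.
Parameters:
optimiser - The gradient optimiser to use.
epochs - The number of SGD epochs (complete passes through the training data).
seed - A seed for the random number generator, used to shuffle the examples before each epoch.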
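The three overloads differ only in which arguments they fill with defaults. A hypothetical settings class that mirrors that delegation chain might look like the following; the optimiser argument is left out so the sketch stays self-contained, and the class name is an invention for illustration.

```java
/** Hypothetical mirror of CRFTrainer's constructor overloads (optimiser omitted). */
final class CrfTrainerSettings {
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;

    CrfTrainerSettings(int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
    }

    /** Minibatch size defaults to 1. */
    CrfTrainerSettings(int epochs, int loggingInterval, long seed) {
        this(epochs, loggingInterval, 1, seed);
    }

    /** Minibatch size defaults to 1 and logging interval to 100. */
    CrfTrainerSettings(int epochs, long seed) {
        this(epochs, 100, seed);
    }
}
```

Following the documented defaults, calling the five-argument constructor with loggingInterval = 100 and minibatchSize = 1 should be equivalent to the three-argument form, and makes those values visible at the call site.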
Method Details
postConfig
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
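A hook like postConfig is needed when a configuration system writes configured values into an object's fields after the object has been created, so any state derived from those fields has to be built afterwards. The sketch below imitates that lifecycle with a plain no-arg constructor and direct field assignment; it does not use OLCUT's real Configurable interface, and deriving a random number generator from the seed is an assumed example of such derived state.

```java
import java.util.SplittableRandom;

/** Hypothetical Configurable-style class: derived state is rebuilt in postConfig(). */
final class ConfigurableSketch {
    long seed;                    // stands in for a field a configuration system sets directly
    private SplittableRandom rng; // derived from seed, so it must be rebuilt after injection

    /** No-arg constructor of the kind a reflective configuration system would call. */
    ConfigurableSketch() {}

    ConfigurableSketch(long seed) {
        this.seed = seed;
        postConfig();             // the programmatic path reuses the same initialisation
    }

    void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    long firstDraw() {
        return rng.nextLong();
    }
}
```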
setShuffle
public void setShuffle(boolean shuffle)
Turns shuffling of the examples on or off. This option isn't exposed in the constructor because shuffling is on by default; turning it off is intended for debugging.
Parameters:
shuffle - If true, shuffle the examples; if false, leave them in their current order.
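The seed and the shuffle flag together determine the order in which examples are visited. A minimal sketch, using java.util.Random and Collections.shuffle as stand-ins for whatever RNG and shuffling routine the trainer uses internally (not documented on this page): with shuffling on, the same seed reproduces the same sequence of epoch orders; with it off, every epoch uses the dataset order, which is what makes setShuffle(false) useful when debugging. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical per-epoch example ordering with a seeded RNG and a shuffle switch. */
final class EpochOrderSketch {
    static List<List<Integer>> epochOrders(int numExamples, int epochs, long seed, boolean shuffle) {
        Random rng = new Random(seed);
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        List<List<Integer>> orders = new ArrayList<>();
        for (int epoch = 0; epoch < epochs; epoch++) {
            if (shuffle) {
                Collections.shuffle(order, rng); // reshuffle before each epoch
            }
            orders.add(new ArrayList<>(order)); // snapshot this epoch's visiting order
        }
        return orders;
    }
}
```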
train
public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SequenceTrainer
Trains a sequence prediction model using the examples in the given data set.
Specified by:
train in interface SequenceTrainer<Label>
Parameters:
sequenceExamples - the data set containing the examples.
runProvenance - Training-run-specific provenance (e.g., fold number).
Returns:
a predictive model that can be used to generate predictions for new examples.
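CRFTrainer implements WeightedExamples, which indicates that train takes per-example weights into account. This page doesn't say exactly how; a common approach, assumed in the sketch below, is to scale each example's gradient by its weight, so a weight of 0 removes an example and a weight of 2 doubles its contribution. The toy one-dimensional loss and the class name are hypothetical.

```java
/** Hypothetical per-example SGD with example weights on a toy loss 0.5 * (w * x - y)^2. */
final class WeightedSgdSketch {
    /** Runs one epoch; each example's gradient is multiplied by its weight before the update. */
    static double epoch(double w, double[] xs, double[] ys, double[] weights, double learningRate) {
        for (int i = 0; i < xs.length; i++) {
            double residual = w * xs[i] - ys[i];
            w -= learningRate * weights[i] * residual * xs[i];
        }
        return w;
    }
}
```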
getInvocationCount
public int getInvocationCount()
Description copied from interface: SequenceTrainer
Returns the number of times the train method has been invoked.
Specified by:
getInvocationCount in interface SequenceTrainer<Label>
Returns:
The number of times train has been invoked.
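The page only says that getInvocationCount reports how many times train has been called. One plausible reason for a trainer to track this is so that repeated calls on the same trainer draw from different but reproducible random streams. The sketch below assumes that design, using java.util.SplittableRandom.split() to give each call its own stream; it is not taken from Tribuo's source, and the class name is hypothetical.

```java
import java.util.SplittableRandom;

/** Hypothetical trainer skeleton: counts train() calls and gives each call its own RNG stream. */
final class InvocationCountingTrainer {
    private final SplittableRandom rng;
    private int invocationCount = 0;

    InvocationCountingTrainer(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stands in for a training run; returns the first draw from this call's RNG stream. */
    long train() {
        SplittableRandom localRng;
        synchronized (this) {
            localRng = rng.split(); // independent, reproducible stream for this call
            invocationCount++;
        }
        return localRng.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }
}
```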
toString
getProvenance
public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>