Class CRFTrainer
java.lang.Object
org.tribuo.classification.sgd.crf.CRFTrainer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SequenceTrainer<Label>, WeightedExamples
A trainer for CRFs using SGD. Modelled after FACTORIE's trainer for CRFs.
See:
Lafferty J, McCallum A, Pereira FC. "Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data." In Proceedings of the 18th International Conference on Machine Learning (ICML 2001).
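For orientation, a linear-chain CRF assigns a score to a label sequence y for inputs x and defines p(y | x) = exp(score(x, y)) / Z(x), where Z(x) sums exp(score) over every possible label sequence. SGD training of such a model typically minimises the negative log-likelihood log Z(x) - score(x, y) of each gold sequence. The minimal sketch below computes that loss with the forward algorithm in log space. It assumes the scores are already available as a per-position unary matrix and a label-to-label transition matrix. The class and method names are hypothetical, they do not mirror Tribuo's internal CRF parameter classes, and no gradients are computed.

```java
/** Hypothetical linear-chain CRF loss; scores are supplied directly as matrices. */
final class LinearChainCrfSketch {

    /** Numerically stable log(sum(exp(v))). */
    static double logSumExp(double[] v) {
        double max = Double.NEGATIVE_INFINITY;
        for (double x : v) {
            max = Math.max(max, x);
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return max;
        }
        double sum = 0.0;
        for (double x : v) {
            sum += Math.exp(x - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Forward algorithm. unary[t][y] scores label y at position t;
     * transition[a][b] scores moving from label a to label b.
     * Returns log Z(x), the log of the summed exp-scores of all label sequences.
     */
    static double logPartition(double[][] unary, double[][] transition) {
        int numLabels = unary[0].length;
        double[] alpha = unary[0].clone(); // log-scores of all prefixes ending in each label
        double[] terms = new double[numLabels];
        for (int t = 1; t < unary.length; t++) {
            double[] next = new double[numLabels];
            for (int b = 0; b < numLabels; b++) {
                for (int a = 0; a < numLabels; a++) {
                    terms[a] = alpha[a] + transition[a][b];
                }
                next[b] = logSumExp(terms) + unary[t][b];
            }
            alpha = next;
        }
        return logSumExp(alpha);
    }

    /** Unnormalised score of one specific label sequence. */
    static double sequenceScore(double[][] unary, double[][] transition, int[] labels) {
        double score = unary[0][labels[0]];
        for (int t = 1; t < labels.length; t++) {
            score += transition[labels[t - 1]][labels[t]] + unary[t][labels[t]];
        }
        return score;
    }

    /** Negative log-likelihood -log p(labels | x), the per-sequence loss SGD would minimise. */
    static double negativeLogLikelihood(double[][] unary, double[][] transition, int[] labels) {
        return logPartition(unary, transition) - sequenceScore(unary, transition, labels);
    }
}
```

Working in log space with logSumExp keeps Z(x) from overflowing on long sequences, and the forward recursion costs O(T * L^2) instead of enumerating all L^T label sequences.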
Constructor Summary
Constructors
- CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed): Creates a CRFTrainer which uses SGD to learn the parameters.
- CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed): Sets the minibatch size to 1.
- CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, long seed): Sets the minibatch size to 1 and the logging interval to 100.
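The three constructors differ only in which hyperparameters fall back to defaults. The hypothetical class below models that relationship with constructor chaining so the defaults are explicit. It replaces the StochasticGradientOptimiser with a plain String so it runs without Tribuo, and whether CRFTrainer itself delegates through this(...) is an assumption; only the default values come from this page.

```java
/**
 * Hypothetical model of how the three CRFTrainer constructors relate.
 * The optimiser is reduced to a String label so no Tribuo classes are needed.
 */
final class CrfTrainerDefaultsSketch {
    final String optimiser;
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;

    // Full form: every hyperparameter is given explicitly.
    CrfTrainerDefaultsSketch(String optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.optimiser = optimiser;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
    }

    // Four-argument form: minibatch size fixed at 1, so each update uses a single example.
    CrfTrainerDefaultsSketch(String optimiser, int epochs, int loggingInterval, long seed) {
        this(optimiser, epochs, loggingInterval, 1, seed);
    }

    // Three-argument form: minibatch size 1 and logging interval 100.
    CrfTrainerDefaultsSketch(String optimiser, int epochs, long seed) {
        this(optimiser, epochs, 100, seed);
    }
}
```

For example, new CrfTrainerDefaultsSketch("adagrad", 5, 42L) ends up with minibatchSize 1 and loggingInterval 100, matching the summary above.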
Method Summary
- int getInvocationCount(): Returns the number of times the train method has been invoked.
- TrainerProvenance getProvenance()
- void postConfig(): Used by the OLCUT configuration system, and should not be called by external code.
- void setShuffle(boolean shuffle): Turn on or off shuffling of examples.
- String toString()
- CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a sequence prediction model using the examples in the given data set.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface org.tribuo.sequence.SequenceTrainer
train
Constructor Details
CRFTrainer
public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
Creates a CRFTrainer which uses SGD to learn the parameters.
- Parameters:
  optimiser - The gradient optimiser to use.
  epochs - The number of SGD epochs (complete passes through the training data).
  loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
  minibatchSize - The size of the minibatches used to aggregate gradients.
  seed - A seed for the random number generator, used to shuffle the examples before each epoch.
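To make these parameters concrete, the following sketch runs the same kind of loop on a toy one-weight least-squares problem instead of a CRF loss: epochs full passes, a seeded reshuffle before each epoch, gradients aggregated over minibatches of minibatchSize examples, and a log line every loggingInterval updates. Everything here is hypothetical. The real trainer hands each aggregated gradient to the StochasticGradientOptimiser rather than taking a fixed-learning-rate step, and averaging the minibatch gradient and counting one iteration per minibatch update are assumptions of this sketch.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical toy trainer: fits y ~ w * x by minibatch SGD on squared loss. */
final class MinibatchSgdSketch {
    static double fit(double[] xs, double[] ys, double learningRate,
                      int epochs, int loggingInterval, int minibatchSize, long seed) {
        Random rng = new Random(seed);
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < xs.length; i++) {
            order.add(i);
        }
        double w = 0.0;
        int iteration = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            // The seed only controls this per-epoch reshuffle of the example order.
            Collections.shuffle(order, rng);
            for (int start = 0; start < order.size(); start += minibatchSize) {
                int end = Math.min(start + minibatchSize, order.size());
                double gradSum = 0.0;
                double lossSum = 0.0;
                for (int k = start; k < end; k++) {
                    int i = order.get(k);
                    double err = w * xs[i] - ys[i];
                    lossSum += 0.5 * err * err;
                    gradSum += err * xs[i]; // derivative of 0.5 * err^2 with respect to w
                }
                // Aggregate the minibatch gradients (averaged here), then take one step.
                w -= learningRate * gradSum / (end - start);
                iteration++;
                // Non-positive intervals such as -1 switch logging off in this sketch.
                if (loggingInterval > 0 && iteration % loggingInterval == 0) {
                    System.out.printf("iteration %d: minibatch loss %.6f%n", iteration, lossSum);
                }
            }
        }
        return w;
    }
}
```

A larger minibatchSize means fewer, smoother updates per epoch; a minibatchSize of 1 gives one update per example.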
CRFTrainer
public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
Sets the minibatch size to 1.
- Parameters:
  optimiser - The gradient optimiser to use.
  epochs - The number of SGD epochs (complete passes through the training data).
  loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
  seed - A seed for the random number generator, used to shuffle the examples before each epoch.
CRFTrainer
public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, long seed)
Sets the minibatch size to 1 and the logging interval to 100.
- Parameters:
  optimiser - The gradient optimiser to use.
  epochs - The number of SGD epochs (complete passes through the training data).
  seed - A seed for the random number generator, used to shuffle the examples before each epoch.
Method Details
postConfig
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
  postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
setShuffle
public void setShuffle(boolean shuffle)
Turn on or off shuffling of examples. Shuffling is not exposed in the constructor because it defaults to on; this method is intended for debugging.
- Parameters:
  shuffle - If true, shuffle the examples; if false, leave them in their current order.
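The sketch below illustrates what the shuffle flag changes, under the assumption that shuffling means a seeded permutation of the example indices before each epoch. The helper name is hypothetical and is not part of Tribuo. With shuffling off, every epoch visits the examples in dataset order; with it on, the same seed reproduces the same sequence of permutations.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: the order in which examples are visited in each epoch. */
final class EpochOrderSketch {
    static List<List<Integer>> epochOrders(int numExamples, int epochs, boolean shuffle, long seed) {
        Random rng = new Random(seed);
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        List<List<Integer>> orders = new ArrayList<>();
        for (int epoch = 0; epoch < epochs; epoch++) {
            if (shuffle) {
                // Shuffle on (the default): a new seeded permutation each epoch.
                Collections.shuffle(order, rng);
            }
            // Shuffle off: the dataset order is reused unchanged in every epoch.
            orders.add(new ArrayList<>(order));
        }
        return orders;
    }
}
```

Turning shuffling off removes one source of run-to-run variation, which is why a fixed order can make a misbehaving training run easier to reproduce and inspect.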
train
public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SequenceTrainer
Trains a sequence prediction model using the examples in the given data set.
- Specified by:
  train in interface SequenceTrainer<Label>
- Parameters:
  sequenceExamples - the data set containing the examples.
  runProvenance - Training run specific provenance (e.g., fold number).
- Returns:
  a predictive model that can be used to generate predictions for new examples.
getInvocationCount
public int getInvocationCount()
Description copied from interface: SequenceTrainer
Returns the number of times the train method has been invoked.
- Specified by:
  getInvocationCount in interface SequenceTrainer<Label>
- Returns:
  The number of times train has been invoked.
toString
public String toString()
- Overrides:
  toString in class java.lang.Object
getProvenance
public TrainerProvenance getProvenance()
- Specified by:
  getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>