public class CRFTrainer extends Object implements SequenceTrainer<Label>, WeightedExamples
See:
Lafferty J, McCallum A, Pereira FC. "Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data" Proceedings of the 18th International Conference on Machine Learning 2001 (ICML 2001).
Constructor and Description |
---|
CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) Creates a CRFTrainer which uses SGD to learn the parameters. |
CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed) Sets the minibatch size to 1. |
CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, long seed) Sets the minibatch size to 1 and the logging interval to 100. |
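The three constructors differ only in which values they default. The sketch below illustrates that relationship with a hypothetical `TrainerSettings` holder; it is not the CRFTrainer source. The optimiser argument is left out to keep the example free of library types, and the delegation from one constructor to the next is an assumption. Only the default values (minibatch size 1, logging interval 100) come from the descriptions above.

```java
/** Hypothetical stand-in for illustration; not the CRFTrainer source. */
final class TrainerSettings {
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;

    /** Full form: every value is supplied by the caller. */
    TrainerSettings(int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
    }

    /** Minibatch size defaults to 1 (delegation is an assumption). */
    TrainerSettings(int epochs, int loggingInterval, long seed) {
        this(epochs, loggingInterval, 1, seed);
    }

    /** Minibatch size defaults to 1 and the logging interval to 100. */
    TrainerSettings(int epochs, long seed) {
        this(epochs, 100, seed);
    }

    public static void main(String[] args) {
        TrainerSettings s = new TrainerSettings(5, 42L);
        System.out.println(s.loggingInterval + " " + s.minibatchSize); // prints "100 1"
    }
}
```

With the shortest form, a caller chooses only the number of epochs and the seed; the longer forms exist for callers who want to control logging frequency or minibatch size.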
Modifier and Type | Method and Description |
---|---|
int | getInvocationCount() Returns the number of times the train method has been invoked. |
TrainerProvenance | getProvenance() |
void | postConfig() |
void | setShuffle(boolean shuffle) Turn on or off shuffling of examples. |
String | toString() |
CRFModel | train(SequenceDataset<Label> sequenceExamples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) Trains a sequence prediction model using the examples in the given data set. |
Methods inherited from class Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface SequenceTrainer: train
public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
Creates a CRFTrainer which uses SGD to learn the parameters.
Parameters:
optimiser - The gradient optimiser to use.
epochs - The number of SGD epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, don't log anything.
minibatchSize - The size of the minibatches used to aggregate gradients.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.

public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
Sets the minibatch size to 1.
Parameters:
optimiser - The gradient optimiser to use.
epochs - The number of SGD epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, don't log anything.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.

public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, long seed)
Sets the minibatch size to 1 and the logging interval to 100.
Parameters:
optimiser - The gradient optimiser to use.
epochs - The number of SGD epochs (complete passes through the training data).
seed - A seed for the random number generator, used to shuffle the examples before each epoch.

public void postConfig()
Specified by: postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
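The constructor parameters above (epochs, loggingInterval, minibatchSize, seed) interact in a training loop roughly as sketched below. This is a hypothetical, self-contained illustration and not the CRFTrainer implementation. In particular, treating one minibatch update as one "iteration" for logging purposes is an assumption, as are the class and method names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch of an SGD-style epoch loop; not the CRFTrainer source. */
final class EpochLoopSketch {

    /** Returns the example indices in the order they were visited across all epochs. */
    static List<Integer> visitOrder(int numExamples, int epochs, int minibatchSize,
                                    int loggingInterval, long seed) {
        Random rng = new Random(seed);
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            indices.add(i);
        }
        List<Integer> visited = new ArrayList<>();
        int iteration = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Reshuffle before each epoch using the seeded generator.
            Collections.shuffle(indices, rng);
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                // A real trainer would aggregate the gradients of these examples
                // and take one optimiser step for the whole minibatch.
                visited.addAll(indices.subList(start, end));
                iteration++;
                // Assumption: one minibatch update counts as one iteration; -1 disables logging.
                if (loggingInterval > 0 && iteration % loggingInterval == 0) {
                    System.out.println("iteration " + iteration + ", epoch " + epoch);
                }
            }
        }
        return visited;
    }

    public static void main(String[] args) {
        // The same seed gives the same visiting order, so runs are reproducible.
        List<Integer> a = visitOrder(6, 2, 4, -1, 1L);
        List<Integer> b = visitOrder(6, 2, 4, -1, 1L);
        System.out.println(a.equals(b)); // prints "true"
    }
}
```

Every example is visited once per epoch, and the last minibatch of an epoch may be smaller than minibatchSize when the data set size is not a multiple of it.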

public void setShuffle(boolean shuffle)
Turn on or off shuffling of examples.
Shuffling defaults to on, so it is not exposed in the constructor. Use this method for debugging.
Parameters:
shuffle - If true, shuffle the examples; if false, leave them in their current order.

public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SequenceTrainer
Trains a sequence prediction model using the examples in the given data set.
Specified by: train in interface SequenceTrainer<Label>
Parameters:
sequenceExamples - the data set containing the examples.
runProvenance - Training run specific provenance (e.g., fold number).

public int getInvocationCount()
Description copied from interface: SequenceTrainer
Returns the number of times the train method has been invoked.
Specified by: getInvocationCount in interface SequenceTrainer<Label>
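The setShuffle and getInvocationCount behaviour described above can be illustrated with a small stub. `TrainerStub` is hypothetical and stands in for the real trainer. Its train method only returns the order in which examples would be visited, and incrementing the counter at the start of each call is an assumption.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical trainer stub; the real CRFTrainer is not reproduced here. */
final class TrainerStub {
    private final Random rng;
    private boolean shuffle = true;  // shuffling defaults to on
    private int invocationCount = 0;

    TrainerStub(long seed) {
        this.rng = new Random(seed);
    }

    void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    int getInvocationCount() {
        return invocationCount;
    }

    /** Returns the order in which the examples would be visited in one pass. */
    List<String> train(List<String> examples) {
        invocationCount++; // assumption: counted once per call
        List<String> order = new ArrayList<>(examples);
        if (shuffle) {
            Collections.shuffle(order, rng);
        }
        return order;
    }

    public static void main(String[] args) {
        TrainerStub trainer = new TrainerStub(1L);
        trainer.setShuffle(false); // keep the input order, e.g. while debugging
        List<String> examples = Arrays.asList("a", "b", "c");
        System.out.println(trainer.train(examples));      // prints "[a, b, c]"
        trainer.train(examples);
        System.out.println(trainer.getInvocationCount()); // prints "2"
    }
}
```

Turning shuffling off makes the visiting order independent of the seed, which is why it is mainly useful for debugging.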

public TrainerProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.