Class ViterbiTrainer
java.lang.Object
org.tribuo.classification.sequence.viterbi.ViterbiTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SequenceTrainer<Label>
Builds a Viterbi model using the supplied Trainer.
Has a parameter to control the label features which are added to the features supplied by the data.
Constructor Summary
Constructor and Description
ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, int stackSize, ViterbiModel.ScoreAggregation scoreAggregation)
Constructs a ViterbiTrainer wrapping the specified trainer.
ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, ViterbiModel.ScoreAggregation scoreAggregation)
Constructs a ViterbiTrainer wrapping the specified trainer, with an unbounded stack size.
Method Summary
Modifier and Type, Method, and Description
int getInvocationCount()
Returns the number of times the train method has been invoked.
TrainerProvenance getProvenance()
String toString()
SequenceModel<Label> train(SequenceDataset<Label> dataset, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
The Viterbi train method is unique because it delegates to a regular Model train method, but before it does, it adds features derived from preceding labels.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable
postConfig
Methods inherited from interface org.tribuo.sequence.SequenceTrainer
train
Constructor Details
ViterbiTrainer
public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, ViterbiModel.ScoreAggregation scoreAggregation)
Constructs a ViterbiTrainer wrapping the specified trainer, with an unbounded stack size.
Parameters:
trainer - The trainer to wrap.
labelFeatureExtractor - The feature extraction function for labels.
scoreAggregation - The score aggregation function.
ViterbiTrainer
public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, int stackSize, ViterbiModel.ScoreAggregation scoreAggregation)
Constructs a ViterbiTrainer wrapping the specified trainer.
Parameters:
trainer - The trainer to wrap.
labelFeatureExtractor - The feature extraction function for labels.
stackSize - The stack size.
scoreAggregation - The score aggregation function.
Method Details
train
public SequenceModel<Label> train(SequenceDataset<Label> dataset, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
The Viterbi train method is unique because it delegates to a regular Model train method, but before it does, it adds features derived from preceding labels.
The pipeline upstream of this call should not need to know that these features are being added. That is, the upstream logic should not have to check which kind of trainer will be used, or carry conditional logic that adds special label-derived features only when the trainer is a ViterbiTrainer. Instead, the label-derived features are generated here and added to the sequence examples of the dataset that is passed in.
If you pass in a MutableSequenceDataset, be aware that this method modifies your dataset, so subsequent calls to other train methods (for example, on a different SequenceTrainer) with the same dataset should be avoided.
If you pass in an ImmutableSequenceDataset, be aware that the entire dataset is copied into a MutableSequenceDataset, so there is a memory penalty.
Specified by:
train in interface SequenceTrainer<Label>
Parameters:
dataset - The input dataset.
runProvenance - Any additional information to record in the provenance.
Returns:
A SequenceModel using Viterbi wrapped around an inner Model.
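To illustrate the idea of adding features derived from preceding labels, here is a minimal, self-contained sketch. It does not use the Tribuo API: the class LabelFeatureSketch, the method withPrecedingLabels, the window parameter, and the "PreviousLabel_k_label" feature names are all hypothetical and chosen only for illustration. The sketch also copies each feature map rather than modifying it in place, which is a design choice of this example and differs from the in-place behaviour described above for a MutableSequenceDataset.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration only; none of these names come from Tribuo.
public class LabelFeatureSketch {

    /**
     * Returns a copy of each token's feature map, extended with one
     * indicator feature for each of up to `window` preceding labels.
     * The input maps are left untouched.
     */
    public static List<Map<String, Double>> withPrecedingLabels(
            List<Map<String, Double>> tokens, List<String> labels, int window) {
        if (tokens.size() != labels.size()) {
            throw new IllegalArgumentException("tokens and labels must have the same length");
        }
        List<Map<String, Double>> out = new ArrayList<>();
        for (int i = 0; i < tokens.size(); i++) {
            // Start from the features supplied by the data.
            Map<String, Double> copy = new LinkedHashMap<>(tokens.get(i));
            // Add one feature per preceding label within the window.
            for (int k = 1; k <= window && i - k >= 0; k++) {
                copy.put("PreviousLabel_" + k + "_" + labels.get(i - k), 1.0);
            }
            out.add(copy);
        }
        return out;
    }

    public static void main(String[] args) {
        List<Map<String, Double>> tokens = List.of(
                Map.of("word=John", 1.0),
                Map.of("word=lives", 1.0),
                Map.of("word=here", 1.0));
        List<String> labels = List.of("PER", "O", "O");
        // Print the feature names of each augmented token.
        for (Map<String, Double> features : withPrecedingLabels(tokens, labels, 1)) {
            System.out.println(features.keySet());
        }
    }
}
```

In this sketch the first token gains no extra features because it has no preceding label, and each later token gains at most window extra features. An inner classifier trained on the augmented maps can then condition on earlier labels without the upstream feature pipeline knowing about them.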
getInvocationCount
public int getInvocationCount()
Description copied from interface: SequenceTrainer
Returns the number of times the train method has been invoked.
Specified by:
getInvocationCount in interface SequenceTrainer<Label>
Returns:
The number of times train has been invoked.
toString
public String toString()
getProvenance
public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>