001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sequence.viterbi;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Dataset;
022import org.tribuo.Example;
023import org.tribuo.Feature;
024import org.tribuo.Model;
025import org.tribuo.Trainer;
026import org.tribuo.classification.Label;
027import org.tribuo.classification.sequence.viterbi.ViterbiModel.ScoreAggregation;
028import org.tribuo.provenance.ModelProvenance;
029import org.tribuo.provenance.TrainerProvenance;
030import org.tribuo.provenance.impl.TrainerProvenanceImpl;
031import org.tribuo.sequence.ImmutableSequenceDataset;
032import org.tribuo.sequence.MutableSequenceDataset;
033import org.tribuo.sequence.SequenceDataset;
034import org.tribuo.sequence.SequenceExample;
035import org.tribuo.sequence.SequenceModel;
036import org.tribuo.sequence.SequenceTrainer;
037
038import java.time.OffsetDateTime;
039import java.util.ArrayList;
040import java.util.List;
041import java.util.Map;
042
043/**
044 * Builds a Viterbi model using the supplied {@link Trainer}.
045 * Has a parameter to control the label features which are added to the features supplied by the data.
046 */
047public final class ViterbiTrainer implements SequenceTrainer<Label> {
048
049    @Config(mandatory = true, description = "Inner trainer for each sequence element.")
050    private Trainer<Label> trainer;
051
052    @Config(mandatory = true, description = "Feature extractor to pull in surrounding label features.")
053    private LabelFeatureExtractor labelFeatureExtractor;
054
055    @Config(mandatory = true, description = "Number of candidate paths.")
056    private int stackSize;
057
058    @Config(mandatory = true, description = "Score aggregation function.")
059    private ScoreAggregation scoreAggregation;
060
061    private int trainInvocationCounter = 0;
062
063    public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor,
064                          ScoreAggregation scoreAggregation) {
065        this(trainer, labelFeatureExtractor, -1, scoreAggregation);
066    }
067
068    public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, int stackSize,
069                          ScoreAggregation scoreAggregation) {
070        this.trainer = trainer;
071        this.labelFeatureExtractor = labelFeatureExtractor;
072        this.stackSize = stackSize;
073        this.scoreAggregation = scoreAggregation;
074    }
075
076    /**
077     * The viterbi train method is unique because it delegates to a regular
078     * {@link Model} train method, but before it does, it adds features derived
079     * from preceding labels. The pipeline upstream of this call should not care
080     * that these features are being added - that is, we would not want to make
081     * the upstream logic worry about what kind of trainer will be used and have
082     * conditional logic that says to add special label-derived features if
083     * using the ViterbiTrainer. So, these one-of-a-kind unique-in-the-world
084     * label-derived features are generated here and added to the sequence
085     * examples of the passed in dataset. If you pass in a
086     * MutableSequenceDataset, then please be aware that your dataset will be
087     * modified after calling this method and therefore subsequent calls to
088     * other SequenceModel.train methods with your dataset should be avoided. If
089     * you pass in an ImmutableSequenceDataset, then please be aware that your
090     * entire dataset is going to be copied as a MutableSequenceDataset - so
091     * there is a memory penalty.
092     * @param dataset The input dataset.
093     * @param runProvenance Any additional information to record in the provenance.
094     * @return A {@link SequenceModel} using Viterbi wrapped around an inner {@link Model}.
095     */
096    @Override
097    public SequenceModel<Label> train(SequenceDataset<Label> dataset, Map<String, Provenance> runProvenance) {
098        if (dataset.getOutputInfo().getUnknownCount() > 0) {
099            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
100        }
101        // if stack size isn't specified, then we will calculate it based on the
102        // number of unique output values
103        if (stackSize == -1) {
104            stackSize = dataset.getOutputIDInfo().size();
105        }
106
107        // create a copy of the dataset to a mutable one. See note above.
108        if (dataset instanceof ImmutableSequenceDataset) {
109            dataset = new MutableSequenceDataset<>((ImmutableSequenceDataset<Label>) dataset);
110        }
111
112        if (!(dataset instanceof MutableSequenceDataset)) {
113            throw new IllegalArgumentException("unable to handle sub-type of dataset: " + dataset.getClass().getName());
114        }
115
116        for (SequenceExample<Label> sequenceExample : dataset) {
117            List<Label> labels = new ArrayList<>();
118
119            for (Example<Label> example : sequenceExample) {
120                List<Feature> labelFeatures = extractFeatures(labels, (MutableSequenceDataset<Label>) dataset,
121                        1.0);
122                example.addAll(labelFeatures);
123                labels.add(example.getOutput());
124            }
125        }
126
127        TrainerProvenance trainerProvenance = getProvenance();
128        ModelProvenance provenance = new ModelProvenance(ViterbiModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), trainerProvenance, runProvenance);
129        trainInvocationCounter++;
130        Dataset<Label> flatData = dataset.getFlatDataset();
131        Model<Label> model = trainer.train(flatData);
132        return new ViterbiModel("viterbi+" + model.getName(), provenance, model,
133                labelFeatureExtractor, stackSize, scoreAggregation);
134    }
135
136    @Override
137    public int getInvocationCount() {
138        return trainInvocationCounter;
139    }
140
141    private List<Feature> extractFeatures(List<Label> labels,
142                                          MutableSequenceDataset<Label> dataset, double value) {
143        List<Feature> labelFeatures = new ArrayList<>();
144        for (Feature labelFeature : labelFeatureExtractor.extractFeatures(labels, value)) {
145            dataset.getFeatureMap().add(labelFeature.getName(), labelFeature.getValue());
146            labelFeatures.add(labelFeature);
147        }
148        return labelFeatures;
149    }
150
151    @Override
152    public String toString() {
153        return "ViterbiTrainer(innerTrainer=" + trainer.toString() + ",labelFeatureExtractor=" + labelFeatureExtractor.toString() + ")";
154    }
155
156    @Override
157    public TrainerProvenance getProvenance() {
158        return new TrainerProvenanceImpl(this);
159    }
160}