001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sequence.viterbi;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Feature;
022import org.tribuo.Model;
023import org.tribuo.Prediction;
024import org.tribuo.classification.Label;
025import org.tribuo.provenance.ModelProvenance;
026import org.tribuo.sequence.SequenceDataset;
027import org.tribuo.sequence.SequenceExample;
028import org.tribuo.sequence.SequenceModel;
029
030import java.util.ArrayList;
031import java.util.Collection;
032import java.util.Collections;
033import java.util.Comparator;
034import java.util.HashMap;
035import java.util.List;
036import java.util.Map;
037import java.util.stream.Collectors;
038
039/**
040 * An implementation of a viterbi model.
041 */
042public class ViterbiModel extends SequenceModel<Label> {
043
044    private static final long serialVersionUID = 1L;
045
046    /**
047     * Types of label score aggregation.
048     */
049    public enum ScoreAggregation {
050        ADD, MULTIPLY
051    }
052
053    private final Model<Label> model;
054
055    private final LabelFeatureExtractor labelFeatureExtractor;
056
057    /**
058     * Specifies the maximum number of candidate paths to keep track of. In general, this number
059     * should be higher than the number of possible classifications at any given point in the
060     * sequence. This guarantees that highest-possible scoring sequence will be returned. If,
061     * however, the number of possible classifications is quite high and/or you are concerned about
062     * throughput performance, then you may want to reduce the number of candidate paths to
063     * maintain.
064     */
065    private final int stackSize;
066
067    /**
068     * Specifies the score aggregation algorithm.
069     */
070    private final ScoreAggregation scoreAggregation;
071
072    ViterbiModel(String name, ModelProvenance description,
073                        Model<Label> model, LabelFeatureExtractor labelFeatureExtractor, int stackSize, ScoreAggregation scoreAggregation) {
074        super(name, description, model.getFeatureIDMap(), model.getOutputIDInfo());
075        this.model = model;
076        this.labelFeatureExtractor = labelFeatureExtractor;
077        this.stackSize = stackSize;
078        this.scoreAggregation = scoreAggregation;
079    }
080
081    @Override
082    public List<List<Prediction<Label>>> predict(SequenceDataset<Label> examples) {
083        List<List<Prediction<Label>>> predictions = new ArrayList<>();
084        for (SequenceExample<Label> e : examples) {
085            predictions.add(predict(e));
086        }
087        return predictions;
088    }
089
090    @Override
091    public List<Prediction<Label>> predict(SequenceExample<Label> examples) {
092        if (stackSize == 1) {
093            List<Label> labels = new ArrayList<>();
094            List<Prediction<Label>> returnValues = new ArrayList<>();
095            for (Example<Label> example : examples) {
096                List<Feature> labelFeatures = extractFeatures(labels);
097                example.addAll(labelFeatures);
098                Prediction<Label> prediction = model.predict(example);
099                labels.add(prediction.getOutput());
100                returnValues.add(prediction);
101            }
102            return returnValues;
103        } else {
104            return viterbi(examples);
105        }
106
107    }
108
109    private List<Feature> extractFeatures(List<Label> labels) {
110        List<Feature> labelFeatures = new ArrayList<>();
111        for (Feature labelFeature : labelFeatureExtractor.extractFeatures(labels, 1.0)) {
112            int id = featureIDMap.getID(labelFeature.getName());
113            if (id > -1) {
114                labelFeatures.add(labelFeature);
115            }
116        }
117        return labelFeatures;
118    }
119
120    /**
121     * This implementation of Viterbi requires at most stackSize * sequenceLength calls to the
122     * classifier. If this proves to be too expensive, then consider using a smaller stack size.
123     *
124     * @param examples a sequence-worth of features. Each {@code List<Feature>} in features should correspond to
125     *                 all of the features for a given element in a sequence to be classified.
126     * @return a list of Predictions - one for each member of the sequence.
127     * @see LabelFeatureExtractor
128     */
129    private List<Prediction<Label>> viterbi(SequenceExample<Label> examples) {
130        // find the best paths through the label lattice
131        Collection<Path> paths = null;
132        int[] numUsed = new int[examples.size()];
133        int i = 0;
134        for (Example<Label> example : examples) {
135            // if this is the first instance, start new paths for each label
136            if (paths == null) {
137                paths = new ArrayList<>();
138                Prediction<Label> prediction = this.model.predict(example);
139                numUsed[i] = prediction.getNumActiveFeatures();
140                Map<String, Label> distribution = prediction.getOutputScores();
141                for (Label label : this.getTopLabels(distribution)) {
142                    paths.add(new Path(label, label.getScore(), null));
143                }
144            } else {
145                // for later instances, find the best previous path for each label
146                Map<Label, Path> maxPaths = new HashMap<>();
147                for (Path path : paths) {
148                    Example<Label> clonedExample = example.copy();
149                    List<Label> previousLabels = new ArrayList<>(path.labels);
150                    List<Feature> labelFeatures = extractFeatures(previousLabels);
151                    clonedExample.addAll(labelFeatures);
152                    Prediction<Label> prediction = this.model.predict(clonedExample);
153                    // TODO this isn't quite correct as it includes label features.
154                    numUsed[i] = prediction.getNumActiveFeatures();
155                    Map<String, Label> distribution = prediction.getOutputScores();
156
157                    for (Label label : this.getTopLabels(distribution)) {
158                        double labelScore = label.getScore();
159                        double score = this.scoreAggregation == ScoreAggregation.ADD ? path.score + labelScore : path.score * labelScore;
160                        Path maxPath = maxPaths.get(label);
161                        if (maxPath == null || score > maxPath.score) {
162                            maxPaths.put(label, new Path(label, score, path));
163                        }
164                    }
165                }
166                paths = maxPaths.values();
167            }
168            i++;
169        }
170
171        Path maxPath = Collections.max(paths);
172
173        ArrayList<Prediction<Label>> output = new ArrayList<>();
174
175        for (int j = 0; j < examples.size(); j++) {
176            Example<Label> e = examples.get(j);
177            output.add(new Prediction<>(maxPath.labels.get(j), numUsed[j], e));
178        }
179
180        return output;
181    }
182
183    protected List<Label> getTopLabels(Map<String, Label> distribution) {
184        return getTopLabels(distribution, this.stackSize);
185    }
186
187    protected static List<Label> getTopLabels(Map<String, Label> distribution, int stackSize) {
188        return distribution.values().stream().sorted(Comparator.comparingDouble(Label::getScore).reversed()).limit(stackSize)
189                .collect(Collectors.toList());
190        // get just the labels that fit within the stack
191    }
192
193    private static class Path implements Comparable<Path> {
194
195        public final double score;
196
197        public final Path parent;
198
199        public final List<Label> labels;
200
201        public Path(Label label, double score, Path parent) {
202            this.score = score;
203            this.parent = parent;
204            this.labels = new ArrayList<>();
205            if (this.parent != null) {
206                this.labels.addAll(this.parent.labels);
207            }
208            this.labels.add(label);
209        }
210
211        @Override
212        public int compareTo(Path that) {
213            return Double.compare(this.score, that.score);
214        }
215
216    }
217
218    public int getStackSize() {
219        return stackSize;
220    }
221
222    public ScoreAggregation getScoreAggregation() {
223        return scoreAggregation;
224    }
225
226    @Override
227    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
228        return model.getTopFeatures(n);
229    }
230
231}