001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sequence.viterbi; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.Feature; 022import org.tribuo.Model; 023import org.tribuo.Prediction; 024import org.tribuo.classification.Label; 025import org.tribuo.provenance.ModelProvenance; 026import org.tribuo.sequence.SequenceDataset; 027import org.tribuo.sequence.SequenceExample; 028import org.tribuo.sequence.SequenceModel; 029 030import java.util.ArrayList; 031import java.util.Collection; 032import java.util.Collections; 033import java.util.Comparator; 034import java.util.HashMap; 035import java.util.List; 036import java.util.Map; 037import java.util.stream.Collectors; 038 039/** 040 * An implementation of a viterbi model. 041 */ 042public class ViterbiModel extends SequenceModel<Label> { 043 044 private static final long serialVersionUID = 1L; 045 046 /** 047 * Types of label score aggregation. 048 */ 049 public enum ScoreAggregation { 050 ADD, MULTIPLY 051 } 052 053 private final Model<Label> model; 054 055 private final LabelFeatureExtractor labelFeatureExtractor; 056 057 /** 058 * Specifies the maximum number of candidate paths to keep track of. In general, this number 059 * should be higher than the number of possible classifications at any given point in the 060 * sequence. This guarantees that highest-possible scoring sequence will be returned. If, 061 * however, the number of possible classifications is quite high and/or you are concerned about 062 * throughput performance, then you may want to reduce the number of candidate paths to 063 * maintain. 064 */ 065 private final int stackSize; 066 067 /** 068 * Specifies the score aggregation algorithm. 069 */ 070 private final ScoreAggregation scoreAggregation; 071 072 ViterbiModel(String name, ModelProvenance description, 073 Model<Label> model, LabelFeatureExtractor labelFeatureExtractor, int stackSize, ScoreAggregation scoreAggregation) { 074 super(name, description, model.getFeatureIDMap(), model.getOutputIDInfo()); 075 this.model = model; 076 this.labelFeatureExtractor = labelFeatureExtractor; 077 this.stackSize = stackSize; 078 this.scoreAggregation = scoreAggregation; 079 } 080 081 @Override 082 public List<List<Prediction<Label>>> predict(SequenceDataset<Label> examples) { 083 List<List<Prediction<Label>>> predictions = new ArrayList<>(); 084 for (SequenceExample<Label> e : examples) { 085 predictions.add(predict(e)); 086 } 087 return predictions; 088 } 089 090 @Override 091 public List<Prediction<Label>> predict(SequenceExample<Label> examples) { 092 if (stackSize == 1) { 093 List<Label> labels = new ArrayList<>(); 094 List<Prediction<Label>> returnValues = new ArrayList<>(); 095 for (Example<Label> example : examples) { 096 List<Feature> labelFeatures = extractFeatures(labels); 097 example.addAll(labelFeatures); 098 Prediction<Label> prediction = model.predict(example); 099 labels.add(prediction.getOutput()); 100 returnValues.add(prediction); 101 } 102 return returnValues; 103 } else { 104 return viterbi(examples); 105 } 106 107 } 108 109 private List<Feature> extractFeatures(List<Label> labels) { 110 List<Feature> labelFeatures = new ArrayList<>(); 111 for (Feature labelFeature : labelFeatureExtractor.extractFeatures(labels, 1.0)) { 112 int id = featureIDMap.getID(labelFeature.getName()); 113 if (id > -1) { 114 labelFeatures.add(labelFeature); 115 } 116 } 117 return labelFeatures; 118 } 119 120 /** 121 * This implementation of Viterbi requires at most stackSize * sequenceLength calls to the 122 * classifier. If this proves to be too expensive, then consider using a smaller stack size. 123 * 124 * @param examples a sequence-worth of features. Each {@code List<Feature>} in features should correspond to 125 * all of the features for a given element in a sequence to be classified. 126 * @return a list of Predictions - one for each member of the sequence. 127 * @see LabelFeatureExtractor 128 */ 129 private List<Prediction<Label>> viterbi(SequenceExample<Label> examples) { 130 // find the best paths through the label lattice 131 Collection<Path> paths = null; 132 int[] numUsed = new int[examples.size()]; 133 int i = 0; 134 for (Example<Label> example : examples) { 135 // if this is the first instance, start new paths for each label 136 if (paths == null) { 137 paths = new ArrayList<>(); 138 Prediction<Label> prediction = this.model.predict(example); 139 numUsed[i] = prediction.getNumActiveFeatures(); 140 Map<String, Label> distribution = prediction.getOutputScores(); 141 for (Label label : this.getTopLabels(distribution)) { 142 paths.add(new Path(label, label.getScore(), null)); 143 } 144 } else { 145 // for later instances, find the best previous path for each label 146 Map<Label, Path> maxPaths = new HashMap<>(); 147 for (Path path : paths) { 148 Example<Label> clonedExample = example.copy(); 149 List<Label> previousLabels = new ArrayList<>(path.labels); 150 List<Feature> labelFeatures = extractFeatures(previousLabels); 151 clonedExample.addAll(labelFeatures); 152 Prediction<Label> prediction = this.model.predict(clonedExample); 153 // TODO this isn't quite correct as it includes label features. 154 numUsed[i] = prediction.getNumActiveFeatures(); 155 Map<String, Label> distribution = prediction.getOutputScores(); 156 157 for (Label label : this.getTopLabels(distribution)) { 158 double labelScore = label.getScore(); 159 double score = this.scoreAggregation == ScoreAggregation.ADD ? path.score + labelScore : path.score * labelScore; 160 Path maxPath = maxPaths.get(label); 161 if (maxPath == null || score > maxPath.score) { 162 maxPaths.put(label, new Path(label, score, path)); 163 } 164 } 165 } 166 paths = maxPaths.values(); 167 } 168 i++; 169 } 170 171 Path maxPath = Collections.max(paths); 172 173 ArrayList<Prediction<Label>> output = new ArrayList<>(); 174 175 for (int j = 0; j < examples.size(); j++) { 176 Example<Label> e = examples.get(j); 177 output.add(new Prediction<>(maxPath.labels.get(j), numUsed[j], e)); 178 } 179 180 return output; 181 } 182 183 protected List<Label> getTopLabels(Map<String, Label> distribution) { 184 return getTopLabels(distribution, this.stackSize); 185 } 186 187 protected static List<Label> getTopLabels(Map<String, Label> distribution, int stackSize) { 188 return distribution.values().stream().sorted(Comparator.comparingDouble(Label::getScore).reversed()).limit(stackSize) 189 .collect(Collectors.toList()); 190 // get just the labels that fit within the stack 191 } 192 193 private static class Path implements Comparable<Path> { 194 195 public final double score; 196 197 public final Path parent; 198 199 public final List<Label> labels; 200 201 public Path(Label label, double score, Path parent) { 202 this.score = score; 203 this.parent = parent; 204 this.labels = new ArrayList<>(); 205 if (this.parent != null) { 206 this.labels.addAll(this.parent.labels); 207 } 208 this.labels.add(label); 209 } 210 211 @Override 212 public int compareTo(Path that) { 213 return Double.compare(this.score, that.score); 214 } 215 216 } 217 218 public int getStackSize() { 219 return stackSize; 220 } 221 222 public ScoreAggregation getScoreAggregation() { 223 return scoreAggregation; 224 } 225 226 @Override 227 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 228 return model.getTopFeatures(n); 229 } 230 231}