/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sgd.crf;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.sequence.ConfidencePredictingSequenceModel;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.sequence.SequenceExample;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.logging.Logger;

import static org.tribuo.Model.BIAS_FEATURE;

/**
 * An inference time model for a CRF trained using SGD.
 * <p>
 * Can be switched to use belief propagation, or constrained BP, at test time instead of the standard Viterbi.
 * <p>
 * See:
 * <pre>
 * Lafferty J, McCallum A, Pereira FC.
 * "Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data"
 * Proceedings of the 18th International Conference on Machine Learning 2001 (ICML 2001).
 * </pre>
 */
public class CRFModel extends ConfidencePredictingSequenceModel {
    private static final Logger logger = Logger.getLogger(CRFModel.class.getName());
    private static final long serialVersionUID = 2L;

    // The trained CRF weights (biases, feature-label weights, label-label transition weights).
    private final CRFParameters parameters;

    /**
     * The type of subsequence level confidence to predict.
     */
    public enum ConfidenceType {
        /**
         * No confidence prediction.
         */
        NONE,
        /**
         * Belief Propagation
         */
        MULTIPLY,
        /**
         * Constrained Belief Propagation from "Confidence Estimation for Information Extraction" Culotta and McCallum 2004.
         */
        CONSTRAINED_BP
    }

    // Selects the inference algorithm used by predict/scoreSubsequences; mutable via setConfidenceType.
    private ConfidenceType confidenceType;

    CRFModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap, CRFParameters parameters) {
        super(name, description, featureIDMap, labelIDMap);
        this.parameters = parameters;
        this.confidenceType = ConfidenceType.NONE;
    }

    /**
     * Sets the inference method used for confidence prediction.
     * If CONSTRAINED_BP uses the constrained belief propagation algorithm from Culotta and McCallum 2004,
     * if MULTIPLY multiplies the maximum marginal for each token, if NONE uses Viterbi.
     *
     * @param type Enum specifying the confidence type.
     */
    public void setConfidenceType(ConfidenceType type) {
        this.confidenceType = type;
    }

    /**
     * Get a copy of the weights for feature {@code featureID}.
     * <p>
     * Returns an empty {@link DenseVector} and logs a warning if the id is out of range.
     * @param featureID The feature ID.
     * @return The per class weights.
     */
    public DenseVector getFeatureWeights(int featureID) {
        if (featureID < 0 || featureID >= featureIDMap.size()) {
            logger.warning("Unknown feature");
            return new DenseVector(0);
        } else {
            return parameters.getFeatureWeights(featureID);
        }
    }

    /**
     * Get a copy of the weights for feature named {@code featureName}.
     * <p>
     * Returns an empty {@link DenseVector} and logs a warning if the feature is unknown.
     * @param featureName The feature name.
     * @return The per class weights.
     */
    public DenseVector getFeatureWeights(String featureName) {
        int id = featureIDMap.getID(featureName);
        if (id > -1) {
            // Reuse the id looked up above rather than querying the feature map a second time.
            return getFeatureWeights(id);
        } else {
            logger.warning("Unknown feature");
            return new DenseVector(0);
        }
    }

    @Override
    public List<Prediction<Label>> predict(SequenceExample<Label> example) {
        SparseVector[] features = convert(example,featureIDMap);
        List<Prediction<Label>> output = new ArrayList<>();
        if (confidenceType == ConfidenceType.MULTIPLY) {
            // Marginal inference: each token's scores are the per-label marginals,
            // and the predicted label is the max marginal label.
            DenseVector[] marginals = parameters.predictMarginals(features);

            for (int i = 0; i < marginals.length; i++) {
                double maxScore = Double.NEGATIVE_INFINITY;
                Label maxLabel = null;
                Map<String,Label> predMap = new LinkedHashMap<>();
                for (int j = 0; j < marginals[i].size(); j++) {
                    String labelName = outputIDMap.getOutput(j).getLabel();
                    Label label = new Label(labelName, marginals[i].get(j));
                    predMap.put(labelName, label);
                    if (label.getScore() > maxScore) {
                        maxScore = label.getScore();
                        maxLabel = label;
                    }
                }
                output.add(new Prediction<>(maxLabel, predMap, features[i].numActiveElements(), example.get(i), true));
            }
        } else {
            // Viterbi inference: predictions carry no per-label scores.
            int[] predLabels = parameters.predict(features);

            for (int i = 0; i < predLabels.length; i++) {
                output.add(new Prediction<>(outputIDMap.getOutput(predLabels[i]),features[i].numActiveElements(),example.get(i)));
            }
        }

        return output;
    }

    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        // A negative n means "all features" (+1 leaves room for the bias entry below).
        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;

        // Uses a standard comparator rather than a Math.abs comparator as it's a log-linear model
        // so nothing is actually negative.
        Comparator<Pair<String,Double>> comparator = Comparator.comparing(Pair::getB);

        //
        // Use a priority queue to find the top N features.
        int numClasses = outputIDMap.size();
        int numFeatures = featureIDMap.size();
        Map<String, List<Pair<String,Double>>> map = new HashMap<>();
        for (int i = 0; i < numClasses; i++) {
            // Min-heap of size maxFeatures; the head is the smallest of the current top weights.
            PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures, comparator);

            for (int j = 0; j < numFeatures; j++) {
                Pair<String,Double> curr = new Pair<>(featureIDMap.get(j).getName(), parameters.getWeight(i,j));

                if (q.size() < maxFeatures) {
                    q.offer(curr);
                } else if (comparator.compare(curr, q.peek()) > 0) {
                    q.poll();
                    q.offer(curr);
                }
            }
            // The bias competes for a slot like any other feature.
            Pair<String,Double> curr = new Pair<>(BIAS_FEATURE, parameters.getBias(i));

            if (q.size() < maxFeatures) {
                q.offer(curr);
            } else if (comparator.compare(curr, q.peek()) > 0) {
                q.poll();
                q.offer(curr);
            }
            ArrayList<Pair<String,Double>> b = new ArrayList<>();
            while (q.size() > 0) {
                b.add(q.poll());
            }

            // Heap drains smallest-first; reverse to present weights in descending order.
            Collections.reverse(b);
            map.put(outputIDMap.getOutput(i).getLabel(), b);
        }
        return map;
    }

    @Override
    public <SUB extends Subsequence> List<Double> scoreSubsequences(SequenceExample<Label> example, List<Prediction<Label>> predictions, List<SUB> subsequences) {
        if (confidenceType == ConfidenceType.CONSTRAINED_BP) {
            // Convert each subsequence's predicted labels into a Chunk of label ids for CBP.
            List<Chunk> chunks = new ArrayList<>();
            for(Subsequence subsequence : subsequences) {
                int[] ids = new int[subsequence.length()];
                for(int i=0; i<ids.length; i++) {
                    ids[i] = outputIDMap.getID(predictions.get(i+subsequence.begin).getOutput());
                }
                chunks.add(new Chunk(subsequence.begin, ids));
            }
            return scoreChunks(example, chunks);
        } else {
            // Fall back to multiplying the per-token prediction scores over each subsequence.
            return ConfidencePredictingSequenceModel.multiplyWeights(predictions, subsequences);
        }
    }

    /**
     * Scores the chunks using constrained belief propagation.
     * @param example The example to score.
     * @param chunks The predicted chunks.
     * @return The scores.
     */
    public List<Double> scoreChunks(SequenceExample<Label> example, List<Chunk> chunks) {
        SparseVector[] features = convert(example,featureIDMap);
        return parameters.predictConfidenceUsingCBP(features,chunks);
    }

    /**
     * Generates a human readable string containing all the weights in this model.
     * @return A string containing all the weight values.
     */
    public String generateWeightsString() {
        StringBuilder buffer = new StringBuilder();

        Tensor[] weights = parameters.get();

        buffer.append("Biases = ");
        buffer.append(weights[0].toString());
        buffer.append('\n');

        buffer.append("Feature-Label weights = \n");
        buffer.append(weights[1].toString());
        buffer.append('\n');

        buffer.append("Label-Label weights = \n");
        buffer.append(weights[2].toString());
        buffer.append('\n');

        return buffer.toString();
    }

    /**
     * Converts a {@link SequenceExample} into an array of {@link SparseVector}s suitable for CRF prediction.
     * @param example The sequence example to convert
     * @param featureIDMap The feature id map, used to discover the number of features.
     * @param <T> The type parameter of the sequence example.
     * @return An array of {@link SparseVector}.
     * @throws IllegalArgumentException If the example is empty or contains an element with no features.
     */
    public static <T extends Output<T>> SparseVector[] convert(SequenceExample<T> example, ImmutableFeatureMap featureIDMap) {
        int length = example.size();
        if (length == 0) {
            throw new IllegalArgumentException("SequenceExample is empty, " + example.toString());
        }
        SparseVector[] features = new SparseVector[length];
        int i = 0;
        for (Example<T> e : example) {
            features[i] = SparseVector.createSparseVector(e,featureIDMap,false);
            if (features[i].numActiveElements() == 0) {
                throw new IllegalArgumentException("No features found in Example " + e.toString());
            }
            i++;
        }
        return features;
    }

    /**
     * Converts a {@link SequenceExample} into an array of {@link SparseVector}s and labels suitable for CRF prediction.
     * @param example The sequence example to convert
     * @param featureIDMap The feature id map, used to discover the number of features.
     * @param labelIDMap The label id map, used to get the index of the labels.
     * @return A {@link Pair} of an int array of labels and an array of {@link SparseVector}.
     * @throws IllegalArgumentException If the example is empty or contains an element with no features.
     */
    public static Pair<int[],SparseVector[]> convert(SequenceExample<Label> example, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
        int length = example.size();
        if (length == 0) {
            throw new IllegalArgumentException("SequenceExample is empty, " + example.toString());
        }
        int[] labels = new int[length];
        SparseVector[] features = new SparseVector[length];
        int i = 0;
        for (Example<Label> e : example) {
            labels[i] = labelIDMap.getID(e.getOutput());
            features[i] = SparseVector.createSparseVector(e,featureIDMap,false);
            if (features[i].numActiveElements() == 0) {
                throw new IllegalArgumentException("No features found in Example " + e.toString());
            }
            i++;
        }
        return new Pair<>(labels,features);
    }
}