001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sgd.crf; 018 019import org.tribuo.math.la.DenseMatrix; 020import org.tribuo.math.la.DenseVector; 021 022import java.util.Arrays; 023 024/** 025 * A collection of helper methods for performing training and inference in a CRF. 026 */ 027public final class ChainHelper { 028 029 private ChainHelper() { } 030 031 /** 032 * Runs belief propagation on a linear chain CRF. Uses the 033 * linear predictions for each token and the label transition probabilities. 034 * @param scores Tuple containing the label-label transition matrix, and the per token label scores. 035 * @return Tuple containing the normalising constant, the forward values and the backward values. 036 */ 037 public static ChainBPResults beliefPropagation(ChainCliqueValues scores) { 038 int numLabels = scores.transitionValues.getDimension1Size(); 039 DenseMatrix markovScores = scores.transitionValues; 040 DenseVector[] localScores = scores.localValues; 041 DenseVector[] alphas = new DenseVector[localScores.length]; 042 DenseVector[] betas = new DenseVector[localScores.length]; 043 for (int i = 0; i < localScores.length; i++) { 044 alphas[i] = localScores[i].copy(); 045 betas[i] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY); 046 } 047 // 048 // Forward pass 049 double[] tmpArray = new double[numLabels]; 050 for (int i = 1; i < localScores.length; i++) { 051 DenseVector curAlpha = alphas[i]; 052 DenseVector prevAlpha = alphas[i - 1]; 053 for (int vi = 0; vi < numLabels; vi++) { 054 for (int vj = 0; vj < numLabels; vj++) { 055 tmpArray[vj] = markovScores.get(vj,vi) + prevAlpha.get(vj); 056 } 057 curAlpha.add(vi,sumLogProbs(tmpArray)); 058 } 059 } 060 // 061 // Backward pass 062 betas[betas.length-1].fill(0.0); 063 for (int i = localScores.length - 2; i >= 0; i--) { 064 DenseVector curBeta = betas[i]; 065 DenseVector prevBeta = betas[i + 1]; 066 DenseVector prevLocalScore = localScores[i + 1]; 067 for (int vi = 0; vi < numLabels; vi++) { 068 for (int vj = 0; vj < numLabels; vj++) { 069 tmpArray[vj] = markovScores.get(vi,vj) + prevBeta.get(vj) + prevLocalScore.get(vj); 070 } 071 curBeta.set(vi,sumLogProbs(tmpArray)); 072 } 073 } 074 double logZ = sumLogProbs(alphas[alphas.length-1]); 075 return new ChainBPResults(logZ, alphas, betas, scores); 076 } 077 078 /** 079 * Runs constrained belief propagation on a linear chain CRF. Uses the 080 * linear predictions for each token and the label transition probabilities. 081 * <p> 082 * See: 083 * <pre> 084 * "Confidence Estimation for Information Extraction", 085 * A. Culotta and A. McCallum 086 * Proceedings of HLT-NAACL 2004: Short Papers, 2004. 087 * </pre> 088 * @param scores Tuple containing the label-label transition matrix, and the per token label scores. 089 * @param constraints An array of integers, representing the label constraints. -1 signifies no constraint, otherwise it's the label id. 090 * @return The normalization constant for this constrained run. 091 */ 092 public static double constrainedBeliefPropagation(ChainCliqueValues scores, int[] constraints) { 093 int numLabels = scores.transitionValues.getDimension1Size(); 094 DenseMatrix markovScores = scores.transitionValues; 095 DenseVector[] localScores = scores.localValues; 096 if (localScores.length != constraints.length) { 097 throw new IllegalArgumentException("Must have the same number of constraints as tokens"); 098 } 099 DenseVector[] alphas = new DenseVector[localScores.length]; 100 for (int i = 0; i < localScores.length; i++) { 101 alphas[i] = localScores[i].copy(); 102 } 103 // 104 // Forward pass 105 double[] tmpArray = new double[numLabels]; 106 for (int i = 1; i < localScores.length; i++) { 107 DenseVector curAlpha = alphas[i]; 108 DenseVector prevAlpha = alphas[i - 1]; 109 for (int vi = 0; vi < numLabels; vi++) { 110 if ((constraints[i] == -1) || (constraints[i] == vi)) { 111 // if unconstrained or path conforms to constraints 112 for (int vj = 0; vj < numLabels; vj++) { 113 tmpArray[vj] = markovScores.get(vj,vi) + prevAlpha.get(vj); 114 } 115 curAlpha.add(vi,sumLogProbs(tmpArray)); 116 } else { 117 // Path is outside constraints, set to zero as alpha is initialised with the local scores. 118 curAlpha.set(vi,Double.NEGATIVE_INFINITY); 119 } 120 } 121 } 122 return sumLogProbs(alphas[alphas.length-1]); 123 } 124 125 /** 126 * Runs Viterbi on a linear chain CRF. Uses the 127 * linear predictions for each token and the label transition probabilities. 128 * @param scores Tuple containing the label-label transition matrix, and the per token label scores. 129 * @return Tuple containing the score of the maximum path and the maximum predicted label per token. 130 */ 131 public static ChainViterbiResults viterbi(ChainCliqueValues scores) { 132 DenseMatrix markovScores = scores.transitionValues; 133 DenseVector[] localScores = scores.localValues; 134 int numLabels = markovScores.getDimension1Size(); 135 DenseVector[] costs = new DenseVector[scores.localValues.length]; 136 int[][] backPointers = new int[scores.localValues.length][]; 137 for (int i = 0; i < scores.localValues.length; i++) { 138 costs[i] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY); 139 backPointers[i] = new int[numLabels]; 140 Arrays.fill(backPointers[i],-1); 141 } 142 costs[0].setElements(localScores[0]); 143 for (int i = 1; i < scores.localValues.length; i++) { 144 DenseVector curLocalScores = localScores[i]; 145 DenseVector curCost = costs[i]; 146 int[] curBackPointers = backPointers[i]; 147 DenseVector prevCost = costs[i - 1]; 148 for (int vi = 0; vi < numLabels; vi++) { 149 double maxScore = Double.NEGATIVE_INFINITY; 150 int maxIndex = -1; 151 double curLocalScore = curLocalScores.get(vi); 152 153 for (int vj = 0; vj < numLabels; vj++) { 154 double curScore = markovScores.get(vj, vi) + prevCost.get(vj) + curLocalScore; 155 if (curScore > maxScore) { 156 maxScore = curScore; 157 maxIndex = vj; 158 } 159 } 160 curCost.set(vi,maxScore); 161 if (maxIndex < 0) { 162 maxIndex = 0; 163 } 164 curBackPointers[vi] = maxIndex; 165 } 166 } 167 int[] mapValues = new int[scores.localValues.length]; 168 mapValues[mapValues.length - 1] = costs[costs.length-1].indexOfMax(); 169 for (int j = mapValues.length - 2; j >= 0; j--) { 170 mapValues[j] = backPointers[j + 1][mapValues[j + 1]]; 171 } 172 return new ChainViterbiResults(costs[costs.length-1].maxValue(), mapValues, scores); 173 } 174 175 /** 176 * Sums the log probabilities. Must be updated in concert with {@link ChainHelper#sumLogProbs(double[])}. 177 * @param input A {@link DenseVector} of log probabilities. 178 * @return log sum exp input[i]. 179 */ 180 public static double sumLogProbs(DenseVector input) { 181 double LOG_TOLERANCE = 30.0; 182 183 double maxValue = input.get(0); 184 int maxIdx = 0; 185 for (int i = 1; i < input.size(); i++) { 186 double value = input.get(i); 187 if (value > maxValue) { 188 maxValue = value; 189 maxIdx = i; 190 } 191 } 192 if (maxValue == Double.NEGATIVE_INFINITY) { 193 return maxValue; 194 } 195 196 boolean anyAdded = false; 197 double intermediate = 0.0; 198 double cutoff = maxValue - LOG_TOLERANCE; 199 for (int i = 0; i < input.size(); i++) { 200 double value = input.get(i); 201 if (value >= cutoff && i != maxIdx && !Double.isInfinite(value)) { 202 anyAdded = true; 203 intermediate += Math.exp(value - maxValue); 204 } 205 } 206 if (anyAdded) { 207 return maxValue + Math.log1p(intermediate); 208 } else { 209 return maxValue; 210 } 211 } 212 213 /** 214 * Sums the log probabilities. Must be updated in concert with {@link ChainHelper#sumLogProbs(DenseVector)}. 215 * @param input A double array of log probabilities. 216 * @return log sum exp input[i]. 217 */ 218 public static double sumLogProbs(double[] input) { 219 double LOG_TOLERANCE = 30.0; 220 221 double maxValue = input[0]; 222 int maxIdx = 0; 223 for (int i = 1; i < input.length; i++) { 224 double value = input[i]; 225 if (value > maxValue) { 226 maxValue = value; 227 maxIdx = i; 228 } 229 } 230 if (maxValue == Double.NEGATIVE_INFINITY) { 231 return maxValue; 232 } 233 234 boolean anyAdded = false; 235 double intermediate = 0.0; 236 double cutoff = maxValue - LOG_TOLERANCE; 237 for (int i = 0; i < input.length; i++) { 238 if (input[i] >= cutoff && i != maxIdx && !Double.isInfinite(input[i])) { 239 anyAdded = true; 240 intermediate += Math.exp(input[i] - maxValue); 241 } 242 } 243 if (anyAdded) { 244 return maxValue + Math.log1p(intermediate); 245 } else { 246 return maxValue; 247 } 248 } 249 250 /** 251 * Belief Propagation results. One day it'll be a record, but not today. 252 */ 253 public static class ChainBPResults { 254 public final double logZ; 255 public final DenseVector[] alphas; 256 public final DenseVector[] betas; 257 public final ChainCliqueValues scores; 258 259 public ChainBPResults(double logZ, DenseVector[] alphas, DenseVector[] betas, ChainCliqueValues scores) { 260 this.logZ = logZ; 261 this.alphas = alphas; 262 this.betas = betas; 263 this.scores = scores; 264 } 265 } 266 267 /** 268 * Clique scores within a chain. One day it'll be a record, but not today. 269 */ 270 public static class ChainCliqueValues { 271 public final DenseVector[] localValues; 272 public final DenseMatrix transitionValues; 273 274 public ChainCliqueValues(DenseVector[] localValues, DenseMatrix transitionValues) { 275 this.localValues = localValues; 276 this.transitionValues = transitionValues; 277 } 278 } 279 280 /** 281 * Viterbi output from a linear chain. One day it'll be a record, but not today. 282 */ 283 public static class ChainViterbiResults { 284 public final double mapScore; 285 public final int[] mapValues; 286 public final ChainCliqueValues scores; 287 288 public ChainViterbiResults(double mapScore, int[] mapValues, ChainCliqueValues scores) { 289 this.mapScore = mapScore; 290 this.mapValues = mapValues; 291 this.scores = scores; 292 } 293 } 294}