001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sgd.crf;
018
019import org.tribuo.math.la.DenseMatrix;
020import org.tribuo.math.la.DenseVector;
021
022import java.util.Arrays;
023
024/**
025 * A collection of helper methods for performing training and inference in a CRF.
026 */
027public final class ChainHelper {
028
029    private ChainHelper() { }
030
031    /**
032     * Runs belief propagation on a linear chain CRF. Uses the
033     * linear predictions for each token and the label transition probabilities.
034     * @param scores Tuple containing the label-label transition matrix, and the per token label scores.
035     * @return Tuple containing the normalising constant, the forward values and the backward values.
036     */
037    public static ChainBPResults beliefPropagation(ChainCliqueValues scores) {
038        int numLabels = scores.transitionValues.getDimension1Size();
039        DenseMatrix markovScores = scores.transitionValues;
040        DenseVector[] localScores = scores.localValues;
041        DenseVector[] alphas = new DenseVector[localScores.length];
042        DenseVector[] betas = new DenseVector[localScores.length];
043        for (int i = 0; i < localScores.length; i++) {
044            alphas[i] = localScores[i].copy();
045            betas[i] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY);
046        }
047        //
048        // Forward pass
049        double[] tmpArray = new double[numLabels];
050        for (int i = 1; i < localScores.length; i++) {
051            DenseVector curAlpha = alphas[i];
052            DenseVector prevAlpha = alphas[i - 1];
053            for (int vi = 0; vi < numLabels; vi++) {
054                for (int vj = 0; vj < numLabels; vj++) {
055                    tmpArray[vj] = markovScores.get(vj,vi) + prevAlpha.get(vj);
056                }
057                curAlpha.add(vi,sumLogProbs(tmpArray));
058            }
059        }
060        //
061        // Backward pass
062        betas[betas.length-1].fill(0.0);
063        for (int i = localScores.length - 2; i >= 0; i--) {
064            DenseVector curBeta = betas[i];
065            DenseVector prevBeta = betas[i + 1];
066            DenseVector prevLocalScore = localScores[i + 1];
067            for (int vi = 0; vi < numLabels; vi++) {
068                for (int vj = 0; vj < numLabels; vj++) {
069                    tmpArray[vj] = markovScores.get(vi,vj) + prevBeta.get(vj) + prevLocalScore.get(vj);
070                }
071                curBeta.set(vi,sumLogProbs(tmpArray));
072            }
073        }
074        double logZ = sumLogProbs(alphas[alphas.length-1]);
075        return new ChainBPResults(logZ, alphas, betas, scores);
076    }
077
078    /**
079     * Runs constrained belief propagation on a linear chain CRF. Uses the
080     * linear predictions for each token and the label transition probabilities.
081     * <p>
082     * See:
083     * <pre>
084     * "Confidence Estimation for Information Extraction",
085     * A. Culotta and A. McCallum
086     * Proceedings of HLT-NAACL 2004: Short Papers, 2004.
087     * </pre>
088     * @param scores Tuple containing the label-label transition matrix, and the per token label scores.
089     * @param constraints An array of integers, representing the label constraints. -1 signifies no constraint, otherwise it's the label id.
090     * @return The normalization constant for this constrained run.
091     */
092    public static double constrainedBeliefPropagation(ChainCliqueValues scores, int[] constraints) {
093        int numLabels = scores.transitionValues.getDimension1Size();
094        DenseMatrix markovScores = scores.transitionValues;
095        DenseVector[] localScores = scores.localValues;
096        if (localScores.length != constraints.length) {
097            throw new IllegalArgumentException("Must have the same number of constraints as tokens");
098        }
099        DenseVector[] alphas = new DenseVector[localScores.length];
100        for (int i = 0; i < localScores.length; i++) {
101            alphas[i] = localScores[i].copy();
102        }
103        //
104        // Forward pass
105        double[] tmpArray = new double[numLabels];
106        for (int i = 1; i < localScores.length; i++) {
107            DenseVector curAlpha = alphas[i];
108            DenseVector prevAlpha = alphas[i - 1];
109            for (int vi = 0; vi < numLabels; vi++) {
110                if ((constraints[i] == -1) || (constraints[i] == vi)) {
111                    // if unconstrained or path conforms to constraints
112                    for (int vj = 0; vj < numLabels; vj++) {
113                        tmpArray[vj] = markovScores.get(vj,vi) + prevAlpha.get(vj);
114                    }
115                    curAlpha.add(vi,sumLogProbs(tmpArray));
116                } else {
117                    // Path is outside constraints, set to zero as alpha is initialised with the local scores.
118                    curAlpha.set(vi,Double.NEGATIVE_INFINITY);
119                }
120            }
121        }
122        return sumLogProbs(alphas[alphas.length-1]);
123    }
124
125    /**
126     * Runs Viterbi on a linear chain CRF. Uses the
127     * linear predictions for each token and the label transition probabilities.
128     * @param scores Tuple containing the label-label transition matrix, and the per token label scores.
129     * @return Tuple containing the score of the maximum path and the maximum predicted label per token.
130     */
131    public static ChainViterbiResults viterbi(ChainCliqueValues scores) {
132        DenseMatrix markovScores = scores.transitionValues;
133        DenseVector[] localScores = scores.localValues;
134        int numLabels = markovScores.getDimension1Size();
135        DenseVector[] costs = new DenseVector[scores.localValues.length];
136        int[][] backPointers = new int[scores.localValues.length][];
137        for (int i = 0; i < scores.localValues.length; i++) {
138            costs[i] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY);
139            backPointers[i] = new int[numLabels];
140            Arrays.fill(backPointers[i],-1);
141        }
142        costs[0].setElements(localScores[0]);
143        for (int i = 1; i < scores.localValues.length; i++) {
144            DenseVector curLocalScores = localScores[i];
145            DenseVector curCost = costs[i];
146            int[] curBackPointers = backPointers[i];
147            DenseVector prevCost = costs[i - 1];
148            for (int vi = 0; vi < numLabels; vi++) {
149                double maxScore = Double.NEGATIVE_INFINITY;
150                int maxIndex = -1;
151                double curLocalScore = curLocalScores.get(vi);
152
153                for (int vj = 0; vj < numLabels; vj++) {
154                    double curScore = markovScores.get(vj, vi) + prevCost.get(vj) + curLocalScore;
155                    if (curScore > maxScore) {
156                        maxScore = curScore;
157                        maxIndex = vj;
158                    }
159                }
160                curCost.set(vi,maxScore);
161                if (maxIndex < 0) {
162                    maxIndex = 0;
163                }
164                curBackPointers[vi] = maxIndex;
165            }
166        }
167        int[] mapValues = new int[scores.localValues.length];
168        mapValues[mapValues.length - 1] = costs[costs.length-1].indexOfMax();
169        for (int j = mapValues.length - 2; j >= 0; j--) {
170            mapValues[j] = backPointers[j + 1][mapValues[j + 1]];
171        }
172        return new ChainViterbiResults(costs[costs.length-1].maxValue(), mapValues, scores);
173    }
174
175    /**
176     * Sums the log probabilities. Must be updated in concert with {@link ChainHelper#sumLogProbs(double[])}.
177     * @param input A {@link DenseVector} of log probabilities.
178     * @return log sum exp input[i].
179     */
180    public static double sumLogProbs(DenseVector input) {
181        double LOG_TOLERANCE = 30.0;
182
183        double maxValue = input.get(0);
184        int maxIdx = 0;
185        for (int i = 1; i < input.size(); i++) {
186            double value = input.get(i);
187            if (value > maxValue) {
188                maxValue = value;
189                maxIdx = i;
190            }
191        }
192        if (maxValue == Double.NEGATIVE_INFINITY) {
193            return maxValue;
194        }
195
196        boolean anyAdded = false;
197        double intermediate = 0.0;
198        double cutoff = maxValue - LOG_TOLERANCE;
199        for (int i = 0; i < input.size(); i++) {
200            double value = input.get(i);
201            if (value >= cutoff && i != maxIdx && !Double.isInfinite(value)) {
202                anyAdded = true;
203                intermediate += Math.exp(value - maxValue);
204            }
205        }
206        if (anyAdded) {
207            return maxValue + Math.log1p(intermediate);
208        } else {
209            return maxValue;
210        }
211    }
212
213    /**
214     * Sums the log probabilities. Must be updated in concert with {@link ChainHelper#sumLogProbs(DenseVector)}.
215     * @param input A double array of log probabilities.
216     * @return log sum exp input[i].
217     */
218    public static double sumLogProbs(double[] input) {
219        double LOG_TOLERANCE = 30.0;
220
221        double maxValue = input[0];
222        int maxIdx = 0;
223        for (int i = 1; i < input.length; i++) {
224            double value = input[i];
225            if (value > maxValue) {
226                maxValue = value;
227                maxIdx = i;
228            }
229        }
230        if (maxValue == Double.NEGATIVE_INFINITY) {
231            return maxValue;
232        }
233
234        boolean anyAdded = false;
235        double intermediate = 0.0;
236        double cutoff = maxValue - LOG_TOLERANCE;
237        for (int i = 0; i < input.length; i++) {
238            if (input[i] >= cutoff && i != maxIdx && !Double.isInfinite(input[i])) {
239                anyAdded = true;
240                intermediate += Math.exp(input[i] - maxValue);
241            }
242        }
243        if (anyAdded) {
244            return maxValue + Math.log1p(intermediate);
245        } else {
246            return maxValue;
247        }
248    }
249
250    /**
251     * Belief Propagation results. One day it'll be a record, but not today.
252     */
253    public static class ChainBPResults {
254        public final double logZ;
255        public final DenseVector[] alphas;
256        public final DenseVector[] betas;
257        public final ChainCliqueValues scores;
258
259        public ChainBPResults(double logZ, DenseVector[] alphas, DenseVector[] betas, ChainCliqueValues scores) {
260            this.logZ = logZ;
261            this.alphas = alphas;
262            this.betas = betas;
263            this.scores = scores;
264        }
265    }
266
267    /**
268     * Clique scores within a chain. One day it'll be a record, but not today.
269     */
270    public static class ChainCliqueValues {
271        public final DenseVector[] localValues;
272        public final DenseMatrix transitionValues;
273
274        public ChainCliqueValues(DenseVector[] localValues, DenseMatrix transitionValues) {
275            this.localValues = localValues;
276            this.transitionValues = transitionValues;
277        }
278    }
279
280    /**
281     * Viterbi output from a linear chain. One day it'll be a record, but not today.
282     */
283    public static class ChainViterbiResults {
284        public final double mapScore;
285        public final int[] mapValues;
286        public final ChainCliqueValues scores;
287
288        public ChainViterbiResults(double mapScore, int[] mapValues, ChainCliqueValues scores) {
289            this.mapScore = mapScore;
290            this.mapValues = mapValues;
291            this.scores = scores;
292        }
293    }
294}