001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sgd.crf;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.math.Parameters;
021import org.tribuo.math.la.DenseMatrix;
022import org.tribuo.math.la.DenseSparseMatrix;
023import org.tribuo.math.la.DenseVector;
024import org.tribuo.math.la.Matrix;
025import org.tribuo.math.la.SparseVector;
026import org.tribuo.math.la.Tensor;
027import org.tribuo.math.util.HeapMerger;
028import org.tribuo.math.util.Merger;
029
030import java.io.Serializable;
031import java.util.ArrayList;
032import java.util.Arrays;
033import java.util.List;
034
035/**
036 * A {@link Parameters} for training a CRF using SGD.
037 */
038public class CRFParameters implements Parameters, Serializable {
039    private static final long serialVersionUID = 1L;
040
041    private final int numLabels;
042    private final int numFeatures;
043
044    private static final Merger merger = new HeapMerger();
045
046    /**
047     * This variable is an array with 3 elements corresponding to the three weight matrices.
048     */
049    private Tensor[] weights;
050
051    private DenseVector biases; //weights[0];
052    private DenseMatrix featureLabelWeights; //weights[1];
053    private DenseMatrix labelLabelWeights; //weights[2];
054
055    CRFParameters(int numFeatures, int numLabels) {
056        this.biases = new DenseVector(numLabels);
057        this.featureLabelWeights = new DenseMatrix(numLabels,numFeatures);
058        this.labelLabelWeights = new DenseMatrix(numLabels,numLabels);
059        this.weights = new Tensor[3];
060        weights[0] = biases;
061        weights[1] = featureLabelWeights;
062        weights[2] = labelLabelWeights;
063        this.numLabels = numLabels;
064        this.numFeatures = numFeatures;
065    }
066
067    public DenseVector getFeatureWeights(int id) {
068        return featureLabelWeights.getColumn(id);
069    }
070
071    public double getBias(int id) {
072        return biases.get(id);
073    }
074
075    public double getWeight(int labelID, int featureID) {
076        return featureLabelWeights.get(labelID, featureID);
077    }
078
079    /**
080     * Generate the local scores (i.e., the linear classifier for each token).
081     * @param features An array of {@link SparseVector}s, one per token.
082     * @return An array of DenseVectors representing the scores per label for each token.
083     */
084    public DenseVector[] getLocalScores(SparseVector[] features) {
085        DenseVector[] localScores = new DenseVector[features.length];
086        for (int i = 0; i < features.length; i++) {
087            DenseVector scores = featureLabelWeights.leftMultiply(features[i]);
088            scores.intersectAndAddInPlace(biases);
089            localScores[i] = scores;
090        }
091        return localScores;
092    }
093
094    /**
095     * Generates the local scores and tuples them with the label - label transition weights.
096     * @param features The per token {@link SparseVector} of features.
097     * @return A tuple containing the array of {@link DenseVector} scores and the label - label transition weights.
098     */
099    public ChainHelper.ChainCliqueValues getCliqueValues(SparseVector[] features) {
100        DenseVector[] localScores = getLocalScores(features);
101        return new ChainHelper.ChainCliqueValues(localScores, labelLabelWeights);
102    }
103
104    /**
105     * Generate a prediction using Viterbi.
106     * @param features The per token {@link SparseVector} of features.
107     * @return An int array giving the predicted label per token.
108     */
109    public int[] predict(SparseVector[] features) {
110        ChainHelper.ChainViterbiResults result = ChainHelper.viterbi(getCliqueValues(features));
111        return result.mapValues;
112    }
113
114    /**
115     * Generate a prediction using Belief Propagation.
116     * @param features The per token {@link SparseVector} of features.
117     * @return A {@link DenseVector} per token containing the marginal distribution over labels.
118     */
119    public DenseVector[] predictMarginals(SparseVector[] features) {
120        ChainHelper.ChainBPResults result = ChainHelper.beliefPropagation(getCliqueValues(features));
121        DenseVector[] marginals = new DenseVector[features.length];
122        for (int i = 0; i < features.length; i++) {
123            marginals[i] = result.alphas[i].add(result.betas[i]);
124            marginals[i].expNormalize(result.logZ);
125        }
126        return marginals;
127    }
128
129    /**
130     * This predicts per chunk confidence using the constrained forward backward algorithm from
131     * Culotta and McCallum 2004.
132     * <p>
133     * Runs one pass of BP to get the normalizing constant, and then a further chunks.size() passes
134     * of constrained forward backward.
135     * @param features The per token {@link SparseVector} of features.
136     * @param chunks A list of extracted chunks to pin constrain the labels to.
137     * @return A list containing the confidence value for each chunk.
138     */
139    public List<Double> predictConfidenceUsingCBP(SparseVector[] features, List<Chunk> chunks) {
140        ChainHelper.ChainCliqueValues cliqueValues = getCliqueValues(features);
141        ChainHelper.ChainBPResults bpResult = ChainHelper.beliefPropagation(cliqueValues);
142        double bpLogZ = bpResult.logZ;
143
144        int[] constraints = new int[features.length];
145
146        List<Double> output = new ArrayList<>();
147        for (Chunk chunk : chunks) {
148            Arrays.fill(constraints,-1);
149            chunk.unpack(constraints);
150            double chunkScore = ChainHelper.constrainedBeliefPropagation(cliqueValues,constraints);
151            output.add(Math.exp(chunkScore - bpLogZ));
152        }
153
154        return output;
155    }
156
157    /**
158     * Generates predictions based on the input features and labels, then scores those predictions to
159     * produce a loss for the example and a gradient update.
160     * @param features The per token {@link SparseVector} of features.
161     * @param labels The per token ground truth labels.
162     * @return A {@link Pair} containing the loss for this example and the associated gradient.
163     */
164    public Pair<Double, Tensor[]> valueAndGradient(SparseVector[] features, int[] labels) {
165        ChainHelper.ChainCliqueValues scores = getCliqueValues(features);
166        // Infer the marginal distribution over labels for each token.
167        ChainHelper.ChainBPResults bpResults = ChainHelper.beliefPropagation(scores);
168        double logZ = bpResults.logZ;
169        DenseVector[] alphas = bpResults.alphas;
170        DenseVector[] betas = bpResults.betas;
171
172        //Calculate the gradients for the parameters.
173        Tensor[] gradient = new Tensor[3];
174        DenseSparseMatrix[] featureGradients = new DenseSparseMatrix[features.length];
175        gradient[0] = new DenseVector(biases.size());
176        Matrix transGradient = new DenseMatrix(numLabels, numLabels);
177        gradient[2] = transGradient;
178        double score = -logZ;
179        for (int i = 0; i < features.length; i++) {
180            int curLabel = labels[i];
181            // Increment the loss based on the score for the true label.
182            DenseVector curLocalScores = scores.localValues[i];
183            score += curLocalScores.get(curLabel);
184            // Generate the predicted local marginal from the BP run.
185            DenseVector curAlpha = alphas[i];
186            DenseVector curBeta = betas[i];
187            DenseVector localMarginal = curAlpha.add(curBeta);
188            // Generate the gradient for the biases based on the true label and predicted label.
189            localMarginal.expNormalize(logZ);
190            localMarginal.scaleInPlace(-1.0);
191            localMarginal.add(curLabel,1.0);
192            gradient[0].intersectAndAddInPlace(localMarginal);
193            // Generate the gradient for the feature - label weights
194            featureGradients[i] = (DenseSparseMatrix) localMarginal.outer(features[i]);
195            // If the sequence has more than one token generate the gradient for the label - label transitions.
196            if (i >= 1) {
197                DenseVector prevAlpha = alphas[i - 1];
198                for (int ii = 0; ii < numLabels; ii++) {
199                    for (int jj = 0; jj < numLabels; jj++) {
200                        double update = -Math.exp(prevAlpha.get(ii) + labelLabelWeights.get(ii,jj) + curBeta.get(jj) + curLocalScores.get(jj) - logZ);
201                        transGradient.add(ii, jj, update);
202                    }
203                }
204                int prevLabel = labels[i-1];
205                // Increment the loss based on the transition from the previous predicted label to the true label.
206                score += (labelLabelWeights.get(prevLabel,curLabel));
207                transGradient.add(prevLabel, curLabel, 1.0);
208            }
209        }
210        // Merge together all the sparse feature - label gradients.
211        gradient[1] = merger.merge(featureGradients);
212
213        return new Pair<>(score,gradient);
214    }
215
216    /**
217     * Returns a 3 element {@link Tensor} array.
218     *
219     * The first element is a {@link DenseVector} of label biases.
220     * The second element is a {@link DenseMatrix} of feature-label weights.
221     * The third element is a {@link DenseMatrix} of label-label transition weights.
222     * @return A {@link Tensor} array.
223     */
224    @Override
225    public Tensor[] getEmptyCopy() {
226        Tensor[] output = new Tensor[3];
227        output[0] = new DenseVector(biases.size());
228        output[1] = new DenseMatrix(featureLabelWeights.getDimension1Size(),featureLabelWeights.getDimension2Size());
229        output[2] = new DenseMatrix(labelLabelWeights.getDimension1Size(),labelLabelWeights.getDimension2Size());
230        return output;
231    }
232
233    @Override
234    public Tensor[] get() {
235        return weights;
236    }
237
238    @Override
239    public void set(Tensor[] newWeights) {
240        if (newWeights.length == weights.length) {
241            weights = newWeights;
242            biases = (DenseVector) weights[0];
243            featureLabelWeights = (DenseMatrix) weights[1];
244            labelLabelWeights = (DenseMatrix) weights[2];
245        }
246    }
247
248    @Override
249    public void update(Tensor[] gradients) {
250        for (int i = 0; i < gradients.length; i++) {
251            weights[i].intersectAndAddInPlace(gradients[i]);
252        }
253    }
254
255    @Override
256    public Tensor[] merge(Tensor[][] gradients, int size) {
257        DenseVector biasUpdate = new DenseVector(biases.size());
258        DenseSparseMatrix[] updates = new DenseSparseMatrix[size];
259        DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(),labelLabelWeights.getDimension2Size());
260        for (int j = 0; j < updates.length; j++) {
261            biasUpdate.intersectAndAddInPlace(gradients[j][0]);
262            updates[j] = (DenseSparseMatrix) gradients[j][1];
263            labelLabelUpdate.intersectAndAddInPlace(gradients[j][2]);
264        }
265
266        DenseSparseMatrix featureLabelUpdate = merger.merge(updates);
267
268        return new Tensor[]{biasUpdate,featureLabelUpdate,labelLabelUpdate};
269    }
270}