/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sgd.crf;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.math.Parameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.util.HeapMerger;
import org.tribuo.math.util.Merger;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * A {@link Parameters} for training a CRF using SGD.
 */
public class CRFParameters implements Parameters, Serializable {
    private static final long serialVersionUID = 1L;

    private final int numLabels;
    private final int numFeatures;

    private static final Merger merger = new HeapMerger();

    /**
     * An array with 3 elements corresponding to the three weight tensors below
     * (biases, feature-label weights, and label-label transition weights).
     */
    private Tensor[] weights;

    private DenseVector biases; //weights[0];
    private DenseMatrix featureLabelWeights; //weights[1];
    private DenseMatrix labelLabelWeights; //weights[2];

    CRFParameters(int numFeatures, int numLabels) {
        this.biases = new DenseVector(numLabels);
        this.featureLabelWeights = new DenseMatrix(numLabels,numFeatures);
        this.labelLabelWeights = new DenseMatrix(numLabels,numLabels);
        this.weights = new Tensor[3];
        weights[0] = biases;
        weights[1] = featureLabelWeights;
        weights[2] = labelLabelWeights;
        this.numLabels = numLabels;
        this.numFeatures = numFeatures;
    }

    /**
     * Gets the feature-label weights for the specified feature id, one weight per label.
     * @param id The feature id.
     * @return A {@link DenseVector} of weights, one per label.
     */
    public DenseVector getFeatureWeights(int id) {
        return featureLabelWeights.getColumn(id);
    }

    /**
     * Gets the bias for the specified label id.
     * @param id The label id.
     * @return The bias.
     */
    public double getBias(int id) {
        return biases.get(id);
    }

    /**
     * Gets the feature-label weight for the specified label and feature ids.
     * @param labelID The label id.
     * @param featureID The feature id.
     * @return The feature-label weight.
     */
    public double getWeight(int labelID, int featureID) {
        return featureLabelWeights.get(labelID, featureID);
    }

    /**
     * Generates the local scores (i.e., the linear classifier for each token).
     * @param features An array of {@link SparseVector}s, one per token.
     * @return An array of {@link DenseVector}s representing the scores per label for each token.
     */
    public DenseVector[] getLocalScores(SparseVector[] features) {
        DenseVector[] localScores = new DenseVector[features.length];
        for (int i = 0; i < features.length; i++) {
            DenseVector scores = featureLabelWeights.leftMultiply(features[i]);
            scores.intersectAndAddInPlace(biases);
            localScores[i] = scores;
        }
        return localScores;
    }
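
    // Note: for each token t the loop above computes localScores[t] = featureLabelWeights * features[t] + biases,
    // i.e. an independent linear model per token; the chain structure is only added in getCliqueValues below.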

    /**
     * Generates the local scores and pairs them with the label-label transition weights.
     * @param features The per token {@link SparseVector} of features.
     * @return A tuple containing the array of {@link DenseVector} scores and the label-label transition weights.
     */
    public ChainHelper.ChainCliqueValues getCliqueValues(SparseVector[] features) {
        DenseVector[] localScores = getLocalScores(features);
        return new ChainHelper.ChainCliqueValues(localScores, labelLabelWeights);
    }

    /**
     * Generates a prediction using Viterbi decoding.
     * @param features The per token {@link SparseVector} of features.
     * @return An int array giving the predicted label per token.
     */
    public int[] predict(SparseVector[] features) {
        ChainHelper.ChainViterbiResults result = ChainHelper.viterbi(getCliqueValues(features));
        return result.mapValues;
    }

    /**
     * Generates a prediction using belief propagation.
     * @param features The per token {@link SparseVector} of features.
     * @return A {@link DenseVector} per token containing the marginal distribution over labels.
     */
    public DenseVector[] predictMarginals(SparseVector[] features) {
        ChainHelper.ChainBPResults result = ChainHelper.beliefPropagation(getCliqueValues(features));
        DenseVector[] marginals = new DenseVector[features.length];
        for (int i = 0; i < features.length; i++) {
            marginals[i] = result.alphas[i].add(result.betas[i]);
            marginals[i].expNormalize(result.logZ);
        }
        return marginals;
    }

    /**
     * Predicts per-chunk confidence using the constrained forward-backward algorithm from
     * Culotta and McCallum 2004.
     * <p>
     * Runs one pass of belief propagation to get the normalizing constant, and then a further
     * chunks.size() passes of constrained forward-backward.
     * @param features The per token {@link SparseVector} of features.
     * @param chunks A list of extracted chunks to constrain the labels to.
     * @return A list containing the confidence value for each chunk.
     */
    public List<Double> predictConfidenceUsingCBP(SparseVector[] features, List<Chunk> chunks) {
        ChainHelper.ChainCliqueValues cliqueValues = getCliqueValues(features);
        ChainHelper.ChainBPResults bpResult = ChainHelper.beliefPropagation(cliqueValues);
        double bpLogZ = bpResult.logZ;

        int[] constraints = new int[features.length];

        List<Double> output = new ArrayList<>();
        for (Chunk chunk : chunks) {
            Arrays.fill(constraints,-1);
            chunk.unpack(constraints);
            double chunkScore = ChainHelper.constrainedBeliefPropagation(cliqueValues,constraints);
            output.add(Math.exp(chunkScore - bpLogZ));
        }

        return output;
    }
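
    /*
     * Usage sketch for the inference methods above (illustrative only; the variable names are
     * hypothetical and the features array is assumed to come from the trainer's feature extraction):
     *
     *   int[] viterbiPath = params.predict(features);                // MAP label sequence
     *   DenseVector[] marginals = params.predictMarginals(features); // per-token label marginals
     */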

    /**
     * Generates predictions based on the input features and labels, then scores those predictions to
     * produce a loss for the example and a gradient update.
     * @param features The per token {@link SparseVector} of features.
     * @param labels The per token ground truth labels.
     * @return A {@link Pair} containing the loss for this example and the associated gradient.
     */
    public Pair<Double, Tensor[]> valueAndGradient(SparseVector[] features, int[] labels) {
        ChainHelper.ChainCliqueValues scores = getCliqueValues(features);
        // Infer the marginal distribution over labels for each token.
        ChainHelper.ChainBPResults bpResults = ChainHelper.beliefPropagation(scores);
        double logZ = bpResults.logZ;
        DenseVector[] alphas = bpResults.alphas;
        DenseVector[] betas = bpResults.betas;

        // Calculate the gradients for the parameters.
        Tensor[] gradient = new Tensor[3];
        DenseSparseMatrix[] featureGradients = new DenseSparseMatrix[features.length];
        gradient[0] = new DenseVector(biases.size());
        Matrix transGradient = new DenseMatrix(numLabels, numLabels);
        gradient[2] = transGradient;
        double score = -logZ;
        for (int i = 0; i < features.length; i++) {
            int curLabel = labels[i];
            // Increment the loss based on the score for the true label.
            DenseVector curLocalScores = scores.localValues[i];
            score += curLocalScores.get(curLabel);
            // Generate the predicted local marginal from the BP run.
            DenseVector curAlpha = alphas[i];
            DenseVector curBeta = betas[i];
            DenseVector localMarginal = curAlpha.add(curBeta);
            // Generate the gradient for the biases based on the true label and the predicted marginal.
            localMarginal.expNormalize(logZ);
            localMarginal.scaleInPlace(-1.0);
            localMarginal.add(curLabel,1.0);
            gradient[0].intersectAndAddInPlace(localMarginal);
            // Generate the gradient for the feature-label weights.
            featureGradients[i] = (DenseSparseMatrix) localMarginal.outer(features[i]);
            // If the sequence has more than one token generate the gradient for the label-label transitions.
            if (i >= 1) {
                DenseVector prevAlpha = alphas[i - 1];
                for (int ii = 0; ii < numLabels; ii++) {
                    for (int jj = 0; jj < numLabels; jj++) {
                        double update = -Math.exp(prevAlpha.get(ii) + labelLabelWeights.get(ii,jj) + curBeta.get(jj) + curLocalScores.get(jj) - logZ);
                        transGradient.add(ii, jj, update);
                    }
                }
                int prevLabel = labels[i-1];
                // Increment the loss based on the transition from the previous true label to the current true label.
                score += labelLabelWeights.get(prevLabel,curLabel);
                transGradient.add(prevLabel, curLabel, 1.0);
            }
        }
        // Merge together all the sparse feature-label gradients.
        gradient[1] = merger.merge(featureGradients);

        return new Pair<>(score,gradient);
    }
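
    /*
     * Minimal gradient step sketch (illustrative only; a real trainer rescales the gradients
     * through its gradient optimiser before applying them):
     *
     *   Pair<Double, Tensor[]> lossAndGrad = params.valueAndGradient(features, labels);
     *   double logLikelihood = lossAndGrad.getA();
     *   params.update(lossAndGrad.getB()); // adds the raw gradients to the current weights
     */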

    /**
     * Returns a 3 element {@link Tensor} array.
     * <p>
     * The first element is a {@link DenseVector} of label biases.
     * The second element is a {@link DenseMatrix} of feature-label weights.
     * The third element is a {@link DenseMatrix} of label-label transition weights.
     * @return A {@link Tensor} array.
     */
    @Override
    public Tensor[] getEmptyCopy() {
        Tensor[] output = new Tensor[3];
        output[0] = new DenseVector(biases.size());
        output[1] = new DenseMatrix(featureLabelWeights.getDimension1Size(),featureLabelWeights.getDimension2Size());
        output[2] = new DenseMatrix(labelLabelWeights.getDimension1Size(),labelLabelWeights.getDimension2Size());
        return output;
    }

    @Override
    public Tensor[] get() {
        return weights;
    }

    @Override
    public void set(Tensor[] newWeights) {
        if (newWeights.length == weights.length) {
            weights = newWeights;
            biases = (DenseVector) weights[0];
            featureLabelWeights = (DenseMatrix) weights[1];
            labelLabelWeights = (DenseMatrix) weights[2];
        }
    }

    @Override
    public void update(Tensor[] gradients) {
        for (int i = 0; i < gradients.length; i++) {
            weights[i].intersectAndAddInPlace(gradients[i]);
        }
    }

    @Override
    public Tensor[] merge(Tensor[][] gradients, int size) {
        DenseVector biasUpdate = new DenseVector(biases.size());
        DenseSparseMatrix[] updates = new DenseSparseMatrix[size];
        DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(),labelLabelWeights.getDimension2Size());
        for (int j = 0; j < updates.length; j++) {
            biasUpdate.intersectAndAddInPlace(gradients[j][0]);
            updates[j] = (DenseSparseMatrix) gradients[j][1];
            labelLabelUpdate.intersectAndAddInPlace(gradients[j][2]);
        }

        DenseSparseMatrix featureLabelUpdate = merger.merge(updates);

        return new Tensor[]{biasUpdate,featureLabelUpdate,labelLabelUpdate};
    }
}