001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sgd.crf;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.ImmutableFeatureMap;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.Output;
024import org.tribuo.Prediction;
025import org.tribuo.classification.Label;
026import org.tribuo.classification.sequence.ConfidencePredictingSequenceModel;
027import org.tribuo.math.la.DenseVector;
028import org.tribuo.math.la.SparseVector;
029import org.tribuo.math.la.Tensor;
030import org.tribuo.provenance.ModelProvenance;
031import org.tribuo.sequence.SequenceExample;
032
033import java.util.ArrayList;
034import java.util.Collections;
035import java.util.Comparator;
036import java.util.HashMap;
037import java.util.LinkedHashMap;
038import java.util.List;
039import java.util.Map;
040import java.util.PriorityQueue;
041import java.util.logging.Logger;
042
043import static org.tribuo.Model.BIAS_FEATURE;
044
045/**
046 * An inference time model for a CRF trained using SGD.
047 * <p>
048 * Can be switched to use belief propagation, or constrained BP, at test time instead of the standard Viterbi.
049 * <p>
050 * See:
051 * <pre>
052 * Lafferty J, McCallum A, Pereira FC.
053 * "Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data"
054 * Proceedings of the 18th International Conference on Machine Learning 2001 (ICML 2001).
055 * </pre>
056 */
057public class CRFModel extends ConfidencePredictingSequenceModel {
058    private static final Logger logger = Logger.getLogger(CRFModel.class.getName());
059    private static final long serialVersionUID = 2L;
060
061    private final CRFParameters parameters;
062
063    /**
064     * The type of subsequence level confidence to predict.
065     */
066    public enum ConfidenceType {
067        /**
068         * No confidence predction.
069         */
070        NONE,
071        /**
072         * Belief Propagation
073         */
074        MULTIPLY,
075        /**
076         * Constrained Belief Propagation from "Confidence Estimation for Information Extraction" Culotta and McCallum 2004.
077         */
078        CONSTRAINED_BP
079    }
080
081    private ConfidenceType confidenceType;
082
083    CRFModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap, CRFParameters parameters) {
084        super(name, description, featureIDMap, labelIDMap);
085        this.parameters = parameters;
086        this.confidenceType = ConfidenceType.NONE;
087    }
088
089    /**
090     * Sets the inference method used for confidence prediction.
091     * If CONSTRAINED_BP uses the constrained belief propagation algorithm from Culotta and McCallum 2004,
092     * if MULTIPLY multiplies the maximum marginal for each token, if NONE uses Viterbi.
093     *
094     * @param type Enum specifying the confidence type.
095     */
096    public void setConfidenceType(ConfidenceType type) {
097        this.confidenceType = type;
098    }
099
100    /**
101     * Get a copy of the weights for feature {@code featureID}.
102     * @param featureID The feature ID.
103     * @return The per class weights.
104     */
105    public DenseVector getFeatureWeights(int featureID) {
106        if (featureID < 0 || featureID >= featureIDMap.size()) {
107            logger.warning("Unknown feature");
108            return new DenseVector(0);
109        } else {
110            return parameters.getFeatureWeights(featureID);
111        }
112    }
113
114    /**
115     * Get a copy of the weights for feature named {@code featureName}.
116     * @param featureName The feature name.
117     * @return The per class weights.
118     */
119    public DenseVector getFeatureWeights(String featureName) {
120        int id = featureIDMap.getID(featureName);
121        if (id > -1) {
122            return getFeatureWeights(featureIDMap.getID(featureName));
123        } else {
124            logger.warning("Unknown feature");
125            return new DenseVector(0);
126        }
127    }
128
129    @Override
130    public List<Prediction<Label>> predict(SequenceExample<Label> example) {
131        SparseVector[] features = convert(example,featureIDMap);
132        List<Prediction<Label>> output = new ArrayList<>();
133        if (confidenceType == ConfidenceType.MULTIPLY) {
134            DenseVector[] marginals = parameters.predictMarginals(features);
135
136            for (int i = 0; i < marginals.length; i++) {
137                double maxScore = Double.NEGATIVE_INFINITY;
138                Label maxLabel = null;
139                Map<String,Label> predMap = new LinkedHashMap<>();
140                for (int j = 0; j < marginals[i].size(); j++) {
141                    String labelName = outputIDMap.getOutput(j).getLabel();
142                    Label label = new Label(labelName, marginals[i].get(j));
143                    predMap.put(labelName, label);
144                    if (label.getScore() > maxScore) {
145                        maxScore = label.getScore();
146                        maxLabel = label;
147                    }
148                }
149                output.add(new Prediction<>(maxLabel, predMap, features[i].numActiveElements(), example.get(i), true));
150            }
151        } else {
152            int[] predLabels = parameters.predict(features);
153
154            for (int i = 0; i < predLabels.length; i++) {
155                output.add(new Prediction<>(outputIDMap.getOutput(predLabels[i]),features[i].numActiveElements(),example.get(i)));
156            }
157        }
158
159        return output;
160    }
161
162    @Override
163    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
164        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;
165
166        // Uses a standard comparator rather than a Math.abs comparator as it's a log-linear model
167        // so nothing is actually negative.
168        Comparator<Pair<String,Double>> comparator = Comparator.comparing(Pair::getB);
169
170        //
171        // Use a priority queue to find the top N features.
172        int numClasses = outputIDMap.size();
173        int numFeatures = featureIDMap.size();
174        Map<String, List<Pair<String,Double>>> map = new HashMap<>();
175        for (int i = 0; i < numClasses; i++) {
176            PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures, comparator);
177
178            for (int j = 0; j < numFeatures; j++) {
179                Pair<String,Double> curr = new Pair<>(featureIDMap.get(j).getName(), parameters.getWeight(i,j));
180
181                if (q.size() < maxFeatures) {
182                    q.offer(curr);
183                } else if (comparator.compare(curr, q.peek()) > 0) {
184                    q.poll();
185                    q.offer(curr);
186                }
187            }
188            Pair<String,Double> curr = new Pair<>(BIAS_FEATURE, parameters.getBias(i));
189
190            if (q.size() < maxFeatures) {
191                q.offer(curr);
192            } else if (comparator.compare(curr, q.peek()) > 0) {
193                q.poll();
194                q.offer(curr);
195            }
196            ArrayList<Pair<String,Double>> b = new ArrayList<>();
197            while (q.size() > 0) {
198                b.add(q.poll());
199            }
200
201            Collections.reverse(b);
202            map.put(outputIDMap.getOutput(i).getLabel(), b);
203        }
204        return map;
205    }
206
207    @Override
208    public <SUB extends Subsequence> List<Double> scoreSubsequences(SequenceExample<Label> example, List<Prediction<Label>> predictions, List<SUB> subsequences) {
209        if (confidenceType == ConfidenceType.CONSTRAINED_BP) {
210            List<Chunk> chunks = new ArrayList<>();
211            for(Subsequence subsequence : subsequences) {
212                int[] ids = new int[subsequence.length()];
213                for(int i=0; i<ids.length; i++) {
214                    ids[i] = outputIDMap.getID(predictions.get(i+subsequence.begin).getOutput());
215                }
216                chunks.add(new Chunk(subsequence.begin, ids));
217            }
218            return scoreChunks(example, chunks);
219        } else {
220            return ConfidencePredictingSequenceModel.multiplyWeights(predictions, subsequences);
221        }
222    }
223
224    /**
225     * Scores the chunks using constrained belief propagation.
226     * @param example The example to score.
227     * @param chunks The predicted chunks.
228     * @return The scores.
229     */
230    public List<Double> scoreChunks(SequenceExample<Label> example, List<Chunk> chunks) {
231        SparseVector[] features = convert(example,featureIDMap);
232        return parameters.predictConfidenceUsingCBP(features,chunks);
233    }
234
235    /**
236     * Generates a human readable string containing all the weights in this model.
237     * @return A string containing all the weight values.
238     */
239    public String generateWeightsString() {
240        StringBuilder buffer = new StringBuilder();
241
242        Tensor[] weights = parameters.get();
243
244        buffer.append("Biases = ");
245        buffer.append(weights[0].toString());
246        buffer.append('\n');
247
248        buffer.append("Feature-Label weights = \n");
249        buffer.append(weights[1].toString());
250        buffer.append('\n');
251
252        buffer.append("Label-Label weights = \n");
253        buffer.append(weights[2].toString());
254        buffer.append('\n');
255
256        return buffer.toString();
257    }
258
259    /**
260     * Converts a {@link SequenceExample} into an array of {@link SparseVector}s suitable for CRF prediction.
261     * @param example The sequence example to convert
262     * @param featureIDMap The feature id map, used to discover the number of features.
263     * @param <T> The type parameter of the sequence example.
264     * @return An array of {@link SparseVector}.
265     */
266    public static <T extends Output<T>> SparseVector[] convert(SequenceExample<T> example, ImmutableFeatureMap featureIDMap) {
267        int length = example.size();
268        if (length == 0) {
269            throw new IllegalArgumentException("SequenceExample is empty, " + example.toString());
270        }
271        SparseVector[] features = new SparseVector[length];
272        int i = 0;
273        for (Example<T> e : example) {
274            features[i] = SparseVector.createSparseVector(e,featureIDMap,false);
275            if (features[i].numActiveElements() == 0) {
276                throw new IllegalArgumentException("No features found in Example " + e.toString());
277            }
278            i++;
279        }
280        return features;
281    }
282
283    /**
284     * Converts a {@link SequenceExample} into an array of {@link SparseVector}s and labels suitable for CRF prediction.
285     * @param example The sequence example to convert
286     * @param featureIDMap The feature id map, used to discover the number of features.
287     * @param labelIDMap The label id map, used to get the index of the labels.
288     * @return A {@link Pair} of an int array of labels and an array of {@link SparseVector}.
289     */
290    public static Pair<int[],SparseVector[]> convert(SequenceExample<Label> example, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
291        int length = example.size();
292        if (length == 0) {
293            throw new IllegalArgumentException("SequenceExample is empty, " + example.toString());
294        }
295        int[] labels = new int[length];
296        SparseVector[] features = new SparseVector[length];
297        int i = 0;
298        for (Example<Label> e : example) {
299            labels[i] = labelIDMap.getID(e.getOutput());
300            features[i] = SparseVector.createSparseVector(e,featureIDMap,false);
301            if (features[i].numActiveElements() == 0) {
302                throw new IllegalArgumentException("No features found in Example " + e.toString());
303            }
304            i++;
305        }
306        return new Pair<>(labels,features);
307    }
308}