/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sgd.crf;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.Util;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceTrainer;

import java.time.OffsetDateTime;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * A trainer for CRFs using SGD. Modelled after FACTORIE's trainer for CRFs.
 * <p>
 * See:
 * <pre>
 * Lafferty J, McCallum A, Pereira FC.
 * "Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data"
 * Proceedings of the 18th International Conference on Machine Learning 2001 (ICML 2001).
 * </pre>
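 * <p>
 * A minimal usage sketch (the AdaGrad optimiser and the {@code trainingData} dataset are
 * illustrative assumptions, not requirements of this class):
 * <pre>
 * CRFTrainer trainer = new CRFTrainer(new AdaGrad(0.1), 5, 12345L);
 * CRFModel model = trainer.train(trainingData, Collections.emptyMap());
 * </pre>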
 */
public class CRFTrainer implements SequenceTrainer<Label>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(CRFTrainer.class.getName());

    @Config(mandatory = true,description="The gradient optimiser to use.")
    private StochasticGradientOptimiser optimiser;

    @Config(description="The number of gradient descent epochs.")
    private int epochs = 5;

    @Config(description="Log values after this many updates.")
    private int loggingInterval = -1;

    @Config(description="Minibatch size in SGD.")
    private int minibatchSize = 1;

    @Config(mandatory = true,description="Seed for the RNG used to shuffle elements.")
    private long seed;

    @Config(description="Shuffle the data before each epoch. Only turn off for debugging.")
    private boolean shuffle = true;

    private SplittableRandom rng;

    private int trainInvocationCounter;

    /**
     * Creates a CRFTrainer which uses SGD to learn the parameters.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of SGD epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param minibatchSize The size of the minibatches used to aggregate gradients.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.optimiser = optimiser;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
        postConfig();
    }

    /**
     * Sets the minibatch size to 1.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of SGD epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed) {
        this(optimiser,epochs,loggingInterval,1,seed);
    }

    /**
     * Sets the minibatch size to 1 and the logging interval to 100.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of SGD epochs (complete passes through the training data).
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, long seed) {
        this(optimiser,epochs,100,1,seed);
    }

    /**
     * For olcut.
     */
    private CRFTrainer() { }

    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Turn on or off shuffling of examples.
     * <p>
     * This isn't exposed in the constructor as it defaults to on.
     * This method should be used for debugging.
     * @param shuffle If true shuffle the examples, if false leave them in their current order.
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

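    /**
     * Trains a {@link CRFModel} on the supplied sequence dataset.
     * <p>
     * The RNG is split and the optimiser copied inside a synchronized block, so each
     * invocation trains using independent local state.
     * @param sequenceExamples The training data.
     * @param runProvenance Provenance information recorded in the model.
     * @return A trained CRF model.
     */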
    @Override
    public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, Provenance> runProvenance) {
        if (sequenceExamples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count, generates a local optimiser.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        StochasticGradientOptimiser localOptimiser;
        synchronized(this) {
            localRNG = rng.split();
            localOptimiser = optimiser.copy();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableOutputInfo<Label> labelIDMap = sequenceExamples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = sequenceExamples.getFeatureIDMap();
        SparseVector[][] sgdFeatures = new SparseVector[sequenceExamples.size()][];
        int[][] sgdLabels = new int[sequenceExamples.size()][];
        double[] weights = new double[sequenceExamples.size()];
        int n = 0;
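        // Convert each SequenceExample into parallel arrays of per-token feature vectors and label ids.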
        for (SequenceExample<Label> example : sequenceExamples) {
            weights[n] = example.getWeight();
            Pair<int[],SparseVector[]> pair = CRFModel.convert(example,featureIDMap,labelIDMap);
            sgdFeatures[n] = pair.getB();
            sgdLabels[n] = pair.getA();
            n++;
        }
        logger.info(String.format("Training SGD CRF with %d examples", n));

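        // CRFParameters holds the weights of the linear chain CRF: the label biases,
        // the feature/label weights, and the label/label transition weights.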
        CRFParameters crfParameters = new CRFParameters(featureIDMap.size(),labelIDMap.size());

        localOptimiser.initialise(crfParameters);
        double loss = 0.0;
        int iteration = 0;

        for (int i = 0; i < epochs; i++) {
            if (shuffle) {
                Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
            }
            if (minibatchSize == 1) {
                /*
                 * Special case a minibatch of size 1. Directly updates the parameters after each
                 * example rather than aggregating.
                 */
                for (int j = 0; j < sgdFeatures.length; j++) {
                    Pair<Double,Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[j],sgdLabels[j]);
                    loss += output.getA()*weights[j];

                    //Update the gradient with the current learning rates
                    Tensor[] updates = localOptimiser.step(output.getB(),weights[j]);

                    //Apply the update to the current parameters.
                    crfParameters.update(updates);

                    iteration++;
                    if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval);
                        loss = 0.0;
                    }
                }
            } else {
                Tensor[][] gradients = new Tensor[minibatchSize][];
                for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                    double tempWeight = 0.0;
                    int curSize = 0;
                    //Aggregate the gradient updates for each example in the minibatch
                    for (int k = j; k < j+minibatchSize && k < sgdFeatures.length; k++) {
                        Pair<Double,Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[k],sgdLabels[k]);
                        loss += output.getA()*weights[k];
                        tempWeight += weights[k];

                        gradients[k-j] = output.getB();
                        curSize++;
                    }
                    //Merge the values into a single gradient update
                    Tensor[] updates = crfParameters.merge(gradients,curSize);
                    for (Tensor update : updates) {
                        update.scaleInPlace(minibatchSize);
                    }
                    tempWeight /= minibatchSize;
                    //Update the gradient with the current learning rates
                    updates = localOptimiser.step(updates,tempWeight);
                    //Apply the gradient.
                    crfParameters.update(updates);

                    iteration++;
                    if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval);
                        loss = 0.0;
                    }
                }
            }
        }
        localOptimiser.finalise();
        ModelProvenance provenance = new ModelProvenance(CRFModel.class.getName(),OffsetDateTime.now(),sequenceExamples.getProvenance(),trainerProvenance,runProvenance);
        CRFModel model = new CRFModel("crf-sgd-model",provenance,featureIDMap,labelIDMap,crfParameters);
        localOptimiser.reset();
        return model;
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "CRFTrainer(optimiser="+optimiser.toString()+",epochs="+epochs+",minibatchSize="+minibatchSize+",seed="+seed+")";
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}