/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sgd.crf;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.Util;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceTrainer;

import java.time.OffsetDateTime;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * A trainer for CRFs using SGD. Modelled after FACTORIE's trainer for CRFs.
 * <p>
 * See:
 * <pre>
 * Lafferty J, McCallum A, Pereira FC.
 * "Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data"
 * Proceedings of the 18th International Conference on Machine Learning 2001 (ICML 2001).
 * </pre>
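 * <p>
 * A minimal usage sketch (the optimiser, step size, epoch count and seed below are
 * illustrative choices, not recommendations; assumes {@code org.tribuo.math.optimisers.AdaGrad}
 * and {@code java.util.Collections} are imported, and that {@code dataset} is a pre-built
 * {@code SequenceDataset<Label>}):
 * <pre>{@code
 * CRFTrainer trainer = new CRFTrainer(new AdaGrad(0.1), 5, 1L);
 * CRFModel model = trainer.train(dataset, Collections.emptyMap());
 * }</pre>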
 */
public class CRFTrainer implements SequenceTrainer<Label>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(CRFTrainer.class.getName());

    @Config(mandatory = true,description="The gradient optimiser to use.")
    private StochasticGradientOptimiser optimiser;

    @Config(description="The number of gradient descent epochs.")
    private int epochs = 5;

    @Config(description="Log values after this many updates.")
    private int loggingInterval = -1;

    @Config(description="Minibatch size in SGD.")
    private int minibatchSize = 1;

    @Config(mandatory = true,description="Seed for the RNG used to shuffle elements.")
    private long seed;

    @Config(description="Shuffle the data before each epoch. Only turn off for debugging.")
    private boolean shuffle = true;

    private SplittableRandom rng;

    private int trainInvocationCounter;

    /**
     * Creates a CRFTrainer which uses SGD to learn the parameters.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of SGD epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param minibatchSize The size of the minibatches used to aggregate gradients.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.optimiser = optimiser;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
        postConfig();
    }

    /**
     * Sets the minibatch size to 1.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of SGD epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed) {
        this(optimiser,epochs,loggingInterval,1,seed);
    }

    /**
     * Sets the minibatch size to 1 and the logging interval to 100.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of SGD epochs (complete passes through the training data).
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, long seed) {
        this(optimiser,epochs,100,1,seed);
    }

    /**
     * For olcut.
     */
    private CRFTrainer() { }

    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Turn on or off shuffling of examples.
     * <p>
     * This isn't exposed in the constructor as it defaults to on.
     * This method should be used for debugging.
     * @param shuffle If true shuffle the examples, if false leave them in their current order.
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    @Override
    public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, Provenance> runProvenance) {
        if (sequenceExamples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count, generates a local optimiser.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        StochasticGradientOptimiser localOptimiser;
        synchronized(this) {
            localRNG = rng.split();
            localOptimiser = optimiser.copy();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableOutputInfo<Label> labelIDMap = sequenceExamples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = sequenceExamples.getFeatureIDMap();
        SparseVector[][] sgdFeatures = new SparseVector[sequenceExamples.size()][];
        int[][] sgdLabels = new int[sequenceExamples.size()][];
        double[] weights = new double[sequenceExamples.size()];
        int n = 0;
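        // Convert each SequenceExample into the SGD representation: per-token feature vectors
        // (pair.getB()) and per-token label ids (pair.getA()), plus the example's weight.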
        for (SequenceExample<Label> example : sequenceExamples) {
            weights[n] = example.getWeight();
            Pair<int[],SparseVector[]> pair = CRFModel.convert(example,featureIDMap,labelIDMap);
            sgdFeatures[n] = pair.getB();
            sgdLabels[n] = pair.getA();
            n++;
        }
        logger.info(String.format("Training SGD CRF with %d examples", n));

        CRFParameters crfParameters = new CRFParameters(featureIDMap.size(),labelIDMap.size());

        localOptimiser.initialise(crfParameters);
        double loss = 0.0;
        int iteration = 0;

        for (int i = 0; i < epochs; i++) {
            if (shuffle) {
                Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
            }
            if (minibatchSize == 1) {
                /*
                 * Special case a minibatch of size 1. Directly updates the parameters after each
                 * example rather than aggregating.
                 */
                for (int j = 0; j < sgdFeatures.length; j++) {
                    Pair<Double,Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[j],sgdLabels[j]);
                    loss += output.getA()*weights[j];

                    //Update the gradient with the current learning rates
                    Tensor[] updates = localOptimiser.step(output.getB(),weights[j]);

                    //Apply the update to the current parameters.
                    crfParameters.update(updates);

                    iteration++;
                    if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval);
                        loss = 0.0;
                    }
                }
            } else {
                Tensor[][] gradients = new Tensor[minibatchSize][];
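                // Minibatch case: accumulate one gradient per example, merge them into a single
                // update, and take one optimiser step using the batch's average example weight.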
                for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                    double tempWeight = 0.0;
                    int curSize = 0;
                    //Aggregate the gradient updates for each example in the minibatch
                    for (int k = j; k < j+minibatchSize && k < sgdFeatures.length; k++) {
                        Pair<Double,Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[k],sgdLabels[k]);
                        loss += output.getA()*weights[k];
                        tempWeight += weights[k];

                        gradients[k-j] = output.getB();
                        curSize++;
                    }
                    //Merge the values into a single gradient update
                    Tensor[] updates = crfParameters.merge(gradients,curSize);
                    for (Tensor update : updates) {
                        update.scaleInPlace(minibatchSize);
                    }
                    tempWeight /= minibatchSize;
                    //Update the gradient with the current learning rates
                    updates = localOptimiser.step(updates,tempWeight);
                    //Apply the gradient.
                    crfParameters.update(updates);

                    iteration++;
                    if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval);
                        loss = 0.0;
                    }
                }
            }
        }
        localOptimiser.finalise();
        ModelProvenance provenance = new ModelProvenance(CRFModel.class.getName(),OffsetDateTime.now(),sequenceExamples.getProvenance(),trainerProvenance,runProvenance);
        CRFModel model = new CRFModel("crf-sgd-model",provenance,featureIDMap,labelIDMap,crfParameters);
        localOptimiser.reset();
        return model;
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "CRFTrainer(optimiser="+optimiser.toString()+",epochs="+epochs+",minibatchSize="+minibatchSize+",seed="+seed+")";
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}