001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sgd.linear; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import com.oracle.labs.mlrg.olcut.util.Pair; 022import org.tribuo.Dataset; 023import org.tribuo.Example; 024import org.tribuo.ImmutableFeatureMap; 025import org.tribuo.ImmutableOutputInfo; 026import org.tribuo.Model; 027import org.tribuo.Trainer; 028import org.tribuo.WeightedExamples; 029import org.tribuo.classification.Label; 030import org.tribuo.classification.sgd.LabelObjective; 031import org.tribuo.classification.sgd.Util; 032import org.tribuo.classification.sgd.objectives.LogMulticlass; 033import org.tribuo.math.LinearParameters; 034import org.tribuo.math.StochasticGradientOptimiser; 035import org.tribuo.math.la.SGDVector; 036import org.tribuo.math.la.SparseVector; 037import org.tribuo.math.la.Tensor; 038import org.tribuo.math.optimisers.AdaGrad; 039import org.tribuo.provenance.ModelProvenance; 040import org.tribuo.provenance.TrainerProvenance; 041import org.tribuo.provenance.impl.TrainerProvenanceImpl; 042 043import java.time.OffsetDateTime; 044import java.util.Map; 045import java.util.SplittableRandom; 046import java.util.logging.Logger; 047 048/** 049 * A trainer for a linear model which uses SGD. 050 * <p> 051 * See: 052 * <pre> 053 * Bottou L. 054 * "Large-Scale Machine Learning with Stochastic Gradient Descent" 055 * Proceedings of COMPSTAT, 2010. 056 * </pre> 057 */ 058public class LinearSGDTrainer implements Trainer<Label>, WeightedExamples { 059 private static final Logger logger = Logger.getLogger(LinearSGDTrainer.class.getName()); 060 061 @Config(description="The classification objective function to use.") 062 private LabelObjective objective = new LogMulticlass(); 063 064 @Config(description="The gradient optimiser to use.") 065 private StochasticGradientOptimiser optimiser = new AdaGrad(1.0,0.1); 066 067 @Config(description="The number of gradient descent epochs.") 068 private int epochs = 5; 069 070 @Config(description="Log values after this many updates.") 071 private int loggingInterval = -1; 072 073 @Config(description="Minibatch size in SGD.") 074 private int minibatchSize = 1; 075 076 @Config(description="Seed for the RNG used to shuffle elements.") 077 private long seed = Trainer.DEFAULT_SEED; 078 079 @Config(description="Shuffle the data before each epoch. Only turn off for debugging.") 080 private boolean shuffle = true; 081 082 private SplittableRandom rng; 083 084 private int trainInvocationCounter; 085 086 /** 087 * Constructs an SGD trainer for a linear model. 088 * @param objective The objective function to optimise. 089 * @param optimiser The gradient optimiser to use. 090 * @param epochs The number of epochs (complete passes through the training data). 091 * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything. 092 * @param minibatchSize The size of any minibatches. 093 * @param seed A seed for the random number generator, used to shuffle the examples before each epoch. 094 */ 095 public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) { 096 this.objective = objective; 097 this.optimiser = optimiser; 098 this.epochs = epochs; 099 this.loggingInterval = loggingInterval; 100 this.minibatchSize = minibatchSize; 101 this.seed = seed; 102 postConfig(); 103 } 104 105 /** 106 * Sets the minibatch size to 1. 107 * @param objective The objective function to optimise. 108 * @param optimiser The gradient optimiser to use. 109 * @param epochs The number of epochs (complete passes through the training data). 110 * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything. 111 * @param seed A seed for the random number generator, used to shuffle the examples before each epoch. 112 */ 113 public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed) { 114 this(objective,optimiser,epochs,loggingInterval,1,seed); 115 } 116 117 /** 118 * Sets the minibatch size to 1 and the logging interval to 1000. 119 * @param objective The objective function to optimise. 120 * @param optimiser The gradient optimiser to use. 121 * @param epochs The number of epochs (complete passes through the training data). 122 * @param seed A seed for the random number generator, used to shuffle the examples before each epoch. 123 */ 124 public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed) { 125 this(objective,optimiser,epochs,1000,1,seed); 126 } 127 128 /** 129 * For olcut. 130 */ 131 private LinearSGDTrainer() { } 132 133 @Override 134 public synchronized void postConfig() { 135 this.rng = new SplittableRandom(seed); 136 } 137 138 /** 139 * Turn on or off shuffling of examples. 140 * <p> 141 * This isn't exposed in the constructor as it defaults to on. 142 * This method should only be used for debugging. 143 * @param shuffle If true shuffle the examples, if false leave them in their current order. 144 */ 145 public void setShuffle(boolean shuffle) { 146 this.shuffle = shuffle; 147 } 148 149 @Override 150 public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) { 151 if (examples.getOutputInfo().getUnknownCount() > 0) { 152 throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised."); 153 } 154 // Creates a new RNG, adds one to the invocation count, generates a local optimiser. 155 TrainerProvenance trainerProvenance; 156 SplittableRandom localRNG; 157 StochasticGradientOptimiser localOptimiser; 158 synchronized(this) { 159 localRNG = rng.split(); 160 localOptimiser = optimiser.copy(); 161 trainerProvenance = getProvenance(); 162 trainInvocationCounter++; 163 } 164 ImmutableOutputInfo<Label> labelIDMap = examples.getOutputIDInfo(); 165 ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap(); 166 SparseVector[] sgdFeatures = new SparseVector[examples.size()]; 167 int[] sgdLabels = new int[examples.size()]; 168 double[] weights = new double[examples.size()]; 169 int n = 0; 170 for (Example<Label> example : examples) { 171 weights[n] = example.getWeight(); 172 sgdFeatures[n] = SparseVector.createSparseVector(example,featureIDMap,true); 173 sgdLabels[n] = labelIDMap.getID(example.getOutput()); 174 n++; 175 } 176 logger.info(String.format("Training SGD classifier with %d examples", n)); 177 logger.info("Labels - " + labelIDMap.toReadableString()); 178 179 // featureIDMap.size()+1 adds the bias feature. 180 LinearParameters linearParameters = new LinearParameters(featureIDMap.size()+1,labelIDMap.size()); 181 182 localOptimiser.initialise(linearParameters); 183 double loss = 0.0; 184 int iteration = 0; 185 186 for (int i = 0; i < epochs; i++) { 187 if (shuffle) { 188 Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG); 189 } 190 if (minibatchSize == 1) { 191 for (int j = 0; j < sgdFeatures.length; j++) { 192 SGDVector pred = linearParameters.predict(sgdFeatures[j]); 193 Pair<Double,SGDVector> output = objective.valueAndGradient(sgdLabels[j],pred); 194 loss += output.getA()*weights[j]; 195 196 Tensor[] updates = localOptimiser.step(linearParameters.gradients(output,sgdFeatures[j]),weights[j]); 197 linearParameters.update(updates); 198 199 iteration++; 200 if ((iteration % loggingInterval == 0) && (loggingInterval != -1)) { 201 logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval); 202 loss = 0.0; 203 } 204 } 205 } else { 206 Tensor[][] gradients = new Tensor[minibatchSize][]; 207 for (int j = 0; j < sgdFeatures.length; j += minibatchSize) { 208 double tempWeight = 0.0; 209 int curSize = 0; 210 for (int k = j; k < j+minibatchSize && k < sgdFeatures.length; k++) { 211 SGDVector pred = linearParameters.predict(sgdFeatures[k]); 212 Pair<Double,SGDVector> output = objective.valueAndGradient(sgdLabels[k],pred); 213 loss += output.getA()*weights[k]; 214 tempWeight += weights[k]; 215 216 gradients[k-j] = linearParameters.gradients(output,sgdFeatures[k]); 217 curSize++; 218 } 219 Tensor[] updates = linearParameters.merge(gradients,curSize); 220 for (int k = 0; k < updates.length; k++) { 221 updates[k].scaleInPlace(minibatchSize); 222 } 223 tempWeight /= minibatchSize; 224 updates = localOptimiser.step(updates,tempWeight); 225 linearParameters.update(updates); 226 227 iteration++; 228 if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) { 229 logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval); 230 loss = 0.0; 231 } 232 } 233 } 234 } 235 localOptimiser.finalise(); 236 ModelProvenance provenance = new ModelProvenance(LinearSGDModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance); 237 //public LinearSGDModel(String name, String description, ImmutableInfoMap featureIDMap, ImmutableInfoMap outputIDInfo, LinearParameters parameters, VectorNormalizer normalizer, boolean generatesProbabilities) { 238 Model<Label> model = new LinearSGDModel("linear-sgd-model",provenance,featureIDMap,labelIDMap,linearParameters,objective.getNormalizer(),objective.isProbabilistic()); 239 localOptimiser.reset(); 240 return model; 241 } 242 243 @Override 244 public int getInvocationCount() { 245 return trainInvocationCounter; 246 } 247 248 @Override 249 public String toString() { 250 return "LinearSGDTrainer(objective="+objective.toString()+",optimiser="+optimiser.toString()+",epochs="+epochs+",minibatchSize="+minibatchSize+",seed="+seed+")"; 251 } 252 253 @Override 254 public TrainerProvenance getProvenance() { 255 return new TrainerProvenanceImpl(this); 256 } 257}