/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.sgd.linear;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.sgd.RegressionObjective;
import org.tribuo.regression.sgd.Util;

import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * A trainer for a linear regression model which uses SGD.
 * Independently trains each output dimension, unless they are tied together in the
 * optimiser.
 * <p>
 * See:
 * <pre>
 * Bottou L.
 * "Large-Scale Machine Learning with Stochastic Gradient Descent"
 * Proceedings of COMPSTAT, 2010.
 * </pre>
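 * <p>
 * A minimal usage sketch, assuming the {@code SquaredLoss} objective from
 * {@code org.tribuo.regression.sgd.objectives}, the {@code AdaGrad} optimiser from
 * {@code org.tribuo.math.optimisers}, and a pre-built {@code trainingDataset}:
 * <pre>{@code
 * LinearSGDTrainer trainer = new LinearSGDTrainer(new SquaredLoss(), new AdaGrad(0.1), 10, 1L);
 * Model<Regressor> model = trainer.train(trainingDataset);
 * }</pre>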
 */
public class LinearSGDTrainer implements Trainer<Regressor>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(LinearSGDTrainer.class.getName());

    @Config(mandatory = true,description="The regression objective to use.")
    private RegressionObjective objective;

    @Config(mandatory = true,description="The gradient optimiser to use.")
    private StochasticGradientOptimiser optimiser;

    @Config(description="The number of gradient descent epochs.")
    private int epochs = 5;

    @Config(description="Log values after this many updates.")
    private int loggingInterval = -1;

    @Config(description="Minibatch size in SGD.")
    private int minibatchSize = 1;

    @Config(mandatory = true,description="Seed for the RNG used to shuffle elements.")
    private long seed;

    @Config(description="Shuffle the data before each epoch. Only turn off for debugging.")
    private boolean shuffle = true;

    private SplittableRandom rng;

    private int trainInvocationCounter;

    /**
     * Constructs an SGD trainer for a linear model.
     * @param objective The objective function to optimise.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param minibatchSize The size of any minibatches.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.objective = objective;
        this.optimiser = optimiser;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
        postConfig();
    }

    /**
     * Sets the minibatch size to 1.
     * @param objective The objective function to optimise.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed) {
        this(objective,optimiser,epochs,loggingInterval,1,seed);
    }

    /**
     * Sets the minibatch size to 1 and the logging interval to 1000.
     * @param objective The objective function to optimise.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed) {
        this(objective,optimiser,epochs,1000,1,seed);
    }

    /**
     * For olcut.
     */
    private LinearSGDTrainer() { }

    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Turn on or off shuffling of examples.
     * <p>
     * This isn't exposed in the constructor as it defaults to on.
     * This method should be used for debugging.
     * @param shuffle If true shuffle the examples, if false leave them in their current order.
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    @Override
    public LinearSGDModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count, generates a local optimiser.
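        // (rng.split() gives this invocation an independent, deterministic RNG and
        // optimiser.copy() a private copy of the optimiser's state, so concurrent
        // train calls don't share mutable state beyond this synchronized block.)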
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        StochasticGradientOptimiser localOptimiser;
        synchronized(this) {
            localRNG = rng.split();
            localOptimiser = optimiser.copy();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        int numOutputs = outputInfo.size();
        SparseVector[] sgdFeatures = new SparseVector[examples.size()];
        DenseVector[] sgdLabels = new DenseVector[examples.size()];
        double[] weights = new double[examples.size()];
        int n = 0;
        double[] regressorsBuffer = new double[numOutputs];
        for (Example<Regressor> example : examples) {
            weights[n] = example.getWeight();
            sgdFeatures[n] = SparseVector.createSparseVector(example,featureIDMap,true);
            Arrays.fill(regressorsBuffer,0.0);
            for (Regressor.DimensionTuple r : example.getOutput()) {
                int id = outputInfo.getID(r);
                regressorsBuffer[id] = r.getValue();
            }
            sgdLabels[n] = DenseVector.createDenseVector(regressorsBuffer);
            n++;
        }
        String[] dimensionNames = new String[numOutputs];
        for (Regressor r : outputInfo.getDomain()) {
            int id = outputInfo.getID(r);
            dimensionNames[id] = r.getNames()[0];
        }
        logger.info(String.format("Training SGD regressor with %d examples", n));
        logger.info("Output variable " + outputInfo.toReadableString());

        // featureIDMap.size()+1 adds the bias feature.
        LinearParameters linearParameters = new LinearParameters(featureIDMap.size()+1,dimensionNames.length);

        localOptimiser.initialise(linearParameters);
        double loss = 0.0;
        int iteration = 0;

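        // Each epoch optionally shuffles the data, then makes one pass of SGD updates:
        // per example when minibatchSize == 1, otherwise per minibatch of merged gradients.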
        for (int i = 0; i < epochs; i++) {
            if (shuffle) {
                Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
            }
            if (minibatchSize == 1) {
                for (int j = 0; j < sgdFeatures.length; j++) {
                    SGDVector pred = linearParameters.predict(sgdFeatures[j]);
                    Pair<Double,SGDVector> output = objective.loss(sgdLabels[j],pred);
                    loss += output.getA()*weights[j];

                    Tensor[] updates = localOptimiser.step(linearParameters.gradients(output,sgdFeatures[j]),weights[j]);
                    linearParameters.update(updates);

                    iteration++;
                    if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval);
                        loss = 0.0;
                    }
                }
            } else {
                Tensor[][] gradients = new Tensor[minibatchSize][];
                for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                    double tempWeight = 0.0;
                    int curSize = 0;
                    for (int k = j; k < j+minibatchSize && k < sgdFeatures.length; k++) {
                        SGDVector pred = linearParameters.predict(sgdFeatures[k]);
                        Pair<Double,SGDVector> output = objective.loss(sgdLabels[k],pred);
                        loss += output.getA()*weights[k];
                        tempWeight += weights[k];

                        gradients[k-j] = linearParameters.gradients(output,sgdFeatures[k]);
                        curSize++;
                    }
                    Tensor[] updates = linearParameters.merge(gradients,curSize);
                    for (int k = 0; k < updates.length; k++) {
                        updates[k].scaleInPlace(minibatchSize);
                    }
                    tempWeight /= minibatchSize;
                    updates = localOptimiser.step(updates,tempWeight);
                    linearParameters.update(updates);

                    iteration++;
                    if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval);
                        loss = 0.0;
                    }
                }
            }
        }
        localOptimiser.finalise();
        ModelProvenance provenance = new ModelProvenance(LinearSGDModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        LinearSGDModel model = new LinearSGDModel("linear-sgd-model",dimensionNames,provenance,featureIDMap,outputInfo,linearParameters);
        return model;
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "LinearSGDTrainer(objective="+objective.toString()+",optimiser="+optimiser.toString()+",epochs="+epochs+",minibatchSize="+minibatchSize+",seed="+seed+")";
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}