/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.sgd.linear;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.sgd.RegressionObjective;
import org.tribuo.regression.sgd.Util;

import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * A trainer for a linear regression model which uses SGD.
 * Independently trains each output dimension, unless they are tied together in the
 * optimiser.
 * <p>
 * See:
 * <pre>
 * Bottou L.
 * "Large-Scale Machine Learning with Stochastic Gradient Descent"
 * Proceedings of COMPSTAT, 2010.
 * </pre>
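 * <p>
 * A minimal usage sketch ({@code SquaredLoss} and {@code AdaGrad} are one possible
 * objective and optimiser pairing, and {@code trainData} stands in for a
 * {@code Dataset<Regressor>}):
 * <pre>{@code
 * RegressionObjective objective = new SquaredLoss();
 * StochasticGradientOptimiser optimiser = new AdaGrad(0.1);
 * LinearSGDTrainer trainer = new LinearSGDTrainer(objective, optimiser, 10, 1000, 1, 42L);
 * Model<Regressor> model = trainer.train(trainData);
 * }</pre>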
 */
public class LinearSGDTrainer implements Trainer<Regressor>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(LinearSGDTrainer.class.getName());

    @Config(mandatory = true,description="The regression objective to use.")
    private RegressionObjective objective;

    @Config(mandatory = true,description="The gradient optimiser to use.")
    private StochasticGradientOptimiser optimiser;

    @Config(description="The number of gradient descent epochs.")
    private int epochs = 5;

    @Config(description="Log values after this many updates.")
    private int loggingInterval = -1;

    @Config(description="Minibatch size in SGD.")
    private int minibatchSize = 1;

    @Config(mandatory = true,description="Seed for the RNG used to shuffle elements.")
    private long seed;

    @Config(description="Shuffle the data before each epoch. Only turn off for debugging.")
    private boolean shuffle = true;

    private SplittableRandom rng;

    private int trainInvocationCounter;

    /**
     * Constructs an SGD trainer for a linear model.
     * @param objective The objective function to optimise.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param minibatchSize The size of any minibatches.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.objective = objective;
        this.optimiser = optimiser;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
        postConfig();
    }

    /**
     * Constructs an SGD trainer for a linear model. Sets the minibatch size to 1.
     * @param objective The objective function to optimise.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed) {
        this(objective,optimiser,epochs,loggingInterval,1,seed);
    }

    /**
     * Constructs an SGD trainer for a linear model. Sets the minibatch size to 1 and the logging interval to 1000.
     * @param objective The objective function to optimise.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed) {
        this(objective,optimiser,epochs,1000,1,seed);
    }

    /**
     * For OLCUT.
     */
    private LinearSGDTrainer() { }

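    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */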
    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Turn on or off shuffling of examples.
     * <p>
     * This isn't exposed in the constructor as it defaults to on.
     * This method should be used for debugging.
     * @param shuffle If true shuffle the examples, if false leave them in their current order.
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    @Override
    public LinearSGDModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count, generates a local optimiser.
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        StochasticGradientOptimiser localOptimiser;
        synchronized(this) {
            localRNG = rng.split();
            localOptimiser = optimiser.copy();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
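        // Extract the feature and output domains, then convert each example into
        // a SparseVector of features and a DenseVector of regression targets.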
        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        int numOutputs = outputInfo.size();
        SparseVector[] sgdFeatures = new SparseVector[examples.size()];
        DenseVector[] sgdLabels = new DenseVector[examples.size()];
        double[] weights = new double[examples.size()];
        int n = 0;
        double[] regressorsBuffer = new double[numOutputs];
        for (Example<Regressor> example : examples) {
            weights[n] = example.getWeight();
            sgdFeatures[n] = SparseVector.createSparseVector(example,featureIDMap,true);
            Arrays.fill(regressorsBuffer,0.0);
            for (Regressor.DimensionTuple r : example.getOutput()) {
                int id = outputInfo.getID(r);
                regressorsBuffer[id] = r.getValue();
            }
            sgdLabels[n] = DenseVector.createDenseVector(regressorsBuffer);
            n++;
        }
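        // Map each output dimension's id to its name, so the model can report per-dimension predictions.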
        String[] dimensionNames = new String[numOutputs];
        for (Regressor r : outputInfo.getDomain()) {
            int id = outputInfo.getID(r);
            dimensionNames[id] = r.getNames()[0];
        }
        logger.info(String.format("Training SGD regressor with %d examples", n));
        logger.info("Output variable " + outputInfo.toReadableString());

        // featureIDMap.size()+1 adds the bias feature.
        LinearParameters linearParameters = new LinearParameters(featureIDMap.size()+1,dimensionNames.length);

        localOptimiser.initialise(linearParameters);
        double loss = 0.0;
        int iteration = 0;

        for (int i = 0; i < epochs; i++) {
            if (shuffle) {
                Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
            }
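            // With a minibatch size of 1, take a gradient step after every example;
            // otherwise accumulate gradients across the minibatch and apply one merged update.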
            if (minibatchSize == 1) {
                for (int j = 0; j < sgdFeatures.length; j++) {
                    SGDVector pred = linearParameters.predict(sgdFeatures[j]);
                    Pair<Double,SGDVector> output = objective.loss(sgdLabels[j],pred);
                    loss += output.getA()*weights[j];

                    Tensor[] updates = localOptimiser.step(linearParameters.gradients(output,sgdFeatures[j]),weights[j]);
                    linearParameters.update(updates);

                    iteration++;
                    if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval);
                        loss = 0.0;
                    }
                }
            } else {
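                // Compute a gradient per example, merge them into a single update,
                // and weight the step by the minibatch's average example weight.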
                Tensor[][] gradients = new Tensor[minibatchSize][];
                for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                    double tempWeight = 0.0;
                    int curSize = 0;
                    for (int k = j; k < j+minibatchSize && k < sgdFeatures.length; k++) {
                        SGDVector pred = linearParameters.predict(sgdFeatures[k]);
                        Pair<Double,SGDVector> output = objective.loss(sgdLabels[k],pred);
                        loss += output.getA()*weights[k];
                        tempWeight += weights[k];

                        gradients[k-j] = linearParameters.gradients(output,sgdFeatures[k]);
                        curSize++;
                    }
                    Tensor[] updates = linearParameters.merge(gradients,curSize);
                    for (int k = 0; k < updates.length; k++) {
                        updates[k].scaleInPlace(minibatchSize);
                    }
                    tempWeight /= minibatchSize;
                    updates = localOptimiser.step(updates,tempWeight);
                    linearParameters.update(updates);

                    iteration++;
                    if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval);
                        loss = 0.0;
                    }
                }
            }
        }
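        // Finalise the optimiser; some optimisers (e.g. ones which average parameters across the run) write back state here.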
        localOptimiser.finalise();
        ModelProvenance provenance = new ModelProvenance(LinearSGDModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        LinearSGDModel model = new LinearSGDModel("linear-sgd-model",dimensionNames,provenance,featureIDMap,outputInfo,linearParameters);
        return model;
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "LinearSGDTrainer(objective="+objective.toString()+",optimiser="+optimiser.toString()+",epochs="+epochs+",minibatchSize="+minibatchSize+",seed="+seed+")";
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}