/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sgd.linear;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.LabelObjective;
import org.tribuo.classification.sgd.Util;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.time.OffsetDateTime;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * A trainer for a linear model which uses SGD.
 * <p>
 * See:
 * <pre>
 * Bottou L.
 * "Large-Scale Machine Learning with Stochastic Gradient Descent"
 * Proceedings of COMPSTAT, 2010.
 * </pre>
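 * <p>
 * Example usage (a minimal sketch; assumes {@code dataset} is an existing
 * {@code Dataset<Label>} constructed elsewhere):
 * <pre>{@code
 * LinearSGDTrainer trainer = new LinearSGDTrainer(
 *     new LogMulticlass(),   // log loss objective (logistic regression)
 *     new AdaGrad(1.0, 0.1), // gradient optimiser
 *     5,                     // epochs
 *     1000,                  // logging interval
 *     Trainer.DEFAULT_SEED);
 * Model<Label> model = trainer.train(dataset);
 * }</pre>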
 */
public class LinearSGDTrainer implements Trainer<Label>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(LinearSGDTrainer.class.getName());

    @Config(description="The classification objective function to use.")
    private LabelObjective objective = new LogMulticlass();

    @Config(description="The gradient optimiser to use.")
    private StochasticGradientOptimiser optimiser = new AdaGrad(1.0,0.1);

    @Config(description="The number of gradient descent epochs.")
    private int epochs = 5;

    @Config(description="Log values after this many updates.")
    private int loggingInterval = -1;

    @Config(description="Minibatch size in SGD.")
    private int minibatchSize = 1;

    @Config(description="Seed for the RNG used to shuffle elements.")
    private long seed = Trainer.DEFAULT_SEED;

    @Config(description="Shuffle the data before each epoch. Only turn off for debugging.")
    private boolean shuffle = true;

    private SplittableRandom rng;

    private int trainInvocationCounter;

    /**
     * Constructs an SGD trainer for a linear model.
     * @param objective The objective function to optimise.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param minibatchSize The size of any minibatches.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.objective = objective;
        this.optimiser = optimiser;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
        postConfig();
    }

    /**
     * Constructs an SGD trainer for a linear model.
     * <p>
     * Sets the minibatch size to 1.
     * @param objective The objective function to optimise.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed) {
        this(objective,optimiser,epochs,loggingInterval,1,seed);
    }

    /**
     * Constructs an SGD trainer for a linear model.
     * <p>
     * Sets the minibatch size to 1 and the logging interval to 1000.
     * @param objective The objective function to optimise.
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed) {
        this(objective,optimiser,epochs,1000,1,seed);
    }

    /**
     * For OLCUT.
     */
    private LinearSGDTrainer() { }

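    /**
     * Re-initialises the shuffling RNG from the seed.
     * Called by the constructors and by the OLCUT configuration system after the fields are set.
     */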
    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Turn on or off shuffling of examples.
     * <p>
     * This isn't exposed in the constructor as it defaults to on.
     * This method should only be used for debugging.
     * @param shuffle If true shuffle the examples, if false leave them in their current order.
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    @Override
    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count, generates a local optimiser.
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        StochasticGradientOptimiser localOptimiser;
        synchronized(this) {
            localRNG = rng.split();
            localOptimiser = optimiser.copy();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableOutputInfo<Label> labelIDMap = examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
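        // Unpack the examples into parallel arrays of feature vectors, label ids and weights.
        // The 'true' flag asks createSparseVector to append an implicit bias feature to each vector.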
        SparseVector[] sgdFeatures = new SparseVector[examples.size()];
        int[] sgdLabels = new int[examples.size()];
        double[] weights = new double[examples.size()];
        int n = 0;
        for (Example<Label> example : examples) {
            weights[n] = example.getWeight();
            sgdFeatures[n] = SparseVector.createSparseVector(example,featureIDMap,true);
            sgdLabels[n] = labelIDMap.getID(example.getOutput());
            n++;
        }
        logger.info(String.format("Training SGD classifier with %d examples", n));
        logger.info("Labels - " + labelIDMap.toReadableString());

        // featureIDMap.size()+1 adds the bias feature.
        LinearParameters linearParameters = new LinearParameters(featureIDMap.size()+1,labelIDMap.size());

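        // Let the optimiser allocate any per-parameter state (e.g., AdaGrad's accumulated squared gradients).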
        localOptimiser.initialise(linearParameters);
        double loss = 0.0;
        int iteration = 0;

        for (int i = 0; i < epochs; i++) {
            if (shuffle) {
                Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
            }
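            // Pure SGD path: take one weighted gradient step per example.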
            if (minibatchSize == 1) {
                for (int j = 0; j < sgdFeatures.length; j++) {
                    SGDVector pred = linearParameters.predict(sgdFeatures[j]);
                    Pair<Double,SGDVector> output = objective.valueAndGradient(sgdLabels[j],pred);
                    loss += output.getA()*weights[j];

                    Tensor[] updates = localOptimiser.step(linearParameters.gradients(output,sgdFeatures[j]),weights[j]);
                    linearParameters.update(updates);

                    iteration++;
                    if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval);
                        loss = 0.0;
                    }
                }
            } else {
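                // Minibatch path: accumulate the gradients for a block of examples,
                // then merge them into a single optimiser step.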
                Tensor[][] gradients = new Tensor[minibatchSize][];
                for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                    double tempWeight = 0.0;
                    int curSize = 0;
                    for (int k = j; k < j+minibatchSize && k < sgdFeatures.length; k++) {
                        SGDVector pred = linearParameters.predict(sgdFeatures[k]);
                        Pair<Double,SGDVector> output = objective.valueAndGradient(sgdLabels[k],pred);
                        loss += output.getA()*weights[k];
                        tempWeight += weights[k];

                        gradients[k-j] = linearParameters.gradients(output,sgdFeatures[k]);
                        curSize++;
                    }
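                    // Merge the per-example gradients into one update, rescale it,
                    // and average the example weights before the single optimiser step.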
                    Tensor[] updates = linearParameters.merge(gradients,curSize);
                    for (int k = 0; k < updates.length; k++) {
                        updates[k].scaleInPlace(minibatchSize);
                    }
                    tempWeight /= minibatchSize;
                    updates = localOptimiser.step(updates,tempWeight);
                    linearParameters.update(updates);

                    iteration++;
                    if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval);
                        loss = 0.0;
                    }
                }
            }
        }
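        // Allow the optimiser to apply any final adjustments to the parameters
        // (used by e.g. parameter averaging), then build the model and reset the local state.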
        localOptimiser.finalise();
        ModelProvenance provenance = new ModelProvenance(LinearSGDModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        Model<Label> model = new LinearSGDModel("linear-sgd-model",provenance,featureIDMap,labelIDMap,linearParameters,objective.getNormalizer(),objective.isProbabilistic());
        localOptimiser.reset();
        return model;
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "LinearSGDTrainer(objective="+objective.toString()+",optimiser="+optimiser.toString()+",epochs="+epochs+",minibatchSize="+minibatchSize+",seed="+seed+")";
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}