/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sgd.kernel;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.Util;
import org.tribuo.math.kernel.Kernel;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.time.OffsetDateTime;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * A trainer for a kernelised model using the Pegasos optimiser.
 * <p>
 * The Pegasos optimiser is extremely sensitive to the lambda parameter, and this
 * value must be tuned to get good performance.
 * <p>
 * See:
 * <pre>
 * Shalev-Shwartz S, Singer Y, Srebro N, Cotter A
 * "Pegasos: Primal Estimated Sub-Gradient Solver for SVM"
 * Mathematical Programming, 2011.
 * </pre>
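 * <p>
 * A minimal usage sketch (assumes an RBF kernel from {@code org.tribuo.math.kernel.RBF}
 * and an existing {@code Dataset<Label>} called {@code trainingData}; the gamma and
 * lambda values below are illustrative and should be tuned for the task at hand):
 * <pre>{@code
 * KernelSVMTrainer trainer = new KernelSVMTrainer(new RBF(1.0), 1e-2, 5, 1L);
 * Model<Label> model = trainer.train(trainingData);
 * }</pre>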
 */
public class KernelSVMTrainer implements Trainer<Label>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(KernelSVMTrainer.class.getName());

    @Config(mandatory = true,description="SVM kernel.")
    private Kernel kernel;

    @Config(mandatory = true,description="L2 regularisation strength (lambda).")
    private double lambda;

    @Config(description="Number of SGD epochs.")
    private int epochs = 5;

    @Config(description="Log values after this many updates.")
    private int loggingInterval = -1;

    @Config(mandatory = true,description="Seed for the RNG used to shuffle elements.")
    private long seed;

    @Config(description="Shuffle the data before each epoch. Only turn off for debugging.")
    private boolean shuffle = true;

    private SplittableRandom rng;

    private int trainInvocationCounter;

    /**
     * Constructs a trainer for a kernel SVM model.
     * @param kernel The kernel function to use as a similarity measure.
     * @param lambda The l2 regulariser on the support vectors.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param loggingInterval Log the loss after this many iterations. If -1, don't log anything.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public KernelSVMTrainer(Kernel kernel, double lambda, int epochs, int loggingInterval, long seed) {
        this.kernel = kernel;
        this.lambda = lambda;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.seed = seed;
        postConfig();
    }

    /**
     * Constructs a trainer for a kernel SVM model.
     * Sets the logging interval to 1000.
     * @param kernel The kernel function to use as a similarity measure.
     * @param lambda The l2 regulariser on the support vectors.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public KernelSVMTrainer(Kernel kernel, double lambda, int epochs, long seed) {
        this(kernel,lambda,epochs,1000,seed);
    }

    /**
     * For olcut.
     */
    private KernelSVMTrainer() { }

    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Turn on or off shuffling of examples.
     * <p>
     * This isn't exposed in the constructor as it defaults to on.
     * This method should only be used for debugging.
     * @param shuffle If true shuffle the examples, if false leave them in their current order.
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    @Override
    public KernelSVMModel train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count.
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        synchronized(this) {
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableOutputInfo<Label> labelIDMap = examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
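        // Extract the examples into parallel arrays of sparse feature vectors,
        // label ids, weights and original indices for fast access during SGD.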
        SparseVector[] sgdFeatures = new SparseVector[examples.size()];
        int[] sgdLabels = new int[examples.size()];
        double[] weights = new double[examples.size()];
        int[] indices = new int[examples.size()];
        int n = 0;
        for (Example<Label> example : examples) {
            weights[n] = example.getWeight();
            sgdFeatures[n] = SparseVector.createSparseVector(example,featureIDMap,true);
            sgdLabels[n] = labelIDMap.getID(example.getOutput());
            indices[n] = n;
            n++;
        }
        logger.info(String.format("Training Kernel SVM with %d examples", n));
        logger.info(labelIDMap.toReadableString());

        double loss = 0.0;
        int iteration = 0;
        Map<Integer,SparseVector> supportVectors = new HashMap<>();
        double[][] alphas = new double[labelIDMap.size()][examples.size()];

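        // Multiclass Pegasos-style updates: score each example against the current
        // support vectors, subtract a margin of 1 from the true label's score, and
        // on a violation add the example to the support set and increment its
        // alpha for the true class.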
        for (int i = 0; i < epochs; i++) {
            if (shuffle) {
                Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, indices, localRNG);
            }
            for (int j = 0; j < sgdFeatures.length; j++) {
                SGDVector pred = predict(sgdFeatures[j],supportVectors,alphas);
                pred.add(sgdLabels[j],-1.0);
                int predIndex = pred.indexOfMax();

                if (sgdLabels[j] != predIndex) {
                    loss += (pred.get(predIndex) - pred.get(sgdLabels[j])) * weights[j];
                    supportVectors.putIfAbsent(indices[j],sgdFeatures[j]);
                    alphas[sgdLabels[j]][indices[j]] += weights[j];
                }

                iteration++;
                if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval + " with " + supportVectors.size() + " support vectors.");
                    loss = 0.0;
                }
            }
            logger.fine("Finished epoch " + i);
        }

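        // Condense the dense alpha array into a matrix with one column per support
        // vector, dropping the columns of examples that never violated the margin.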
        DenseMatrix alphaMatrix = new DenseMatrix(alphas.length,supportVectors.size());
        for (int i = 0; i < alphas.length; i++) {
            int columnCounter = 0;
            for (int j = 0; j < sgdFeatures.length; j++) {
                if (supportVectors.containsKey(j)) {
                    alphaMatrix.set(i, columnCounter, alphas[i][j]);
                    columnCounter++;
                }
            }
        }

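        // Gather the support vectors in ascending example index order, so that
        // column c of alphaMatrix corresponds to supportArray[c].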
        int counter = 0;
        SparseVector[] supportArray = new SparseVector[supportVectors.size()];
        for (int i = 0; i < sgdFeatures.length; i++) {
            SparseVector value = supportVectors.get(i);
            if (value != null) {
                supportArray[counter] = value;
                counter++;
            }
        }

        ModelProvenance provenance = new ModelProvenance(KernelSVMModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        return new KernelSVMModel("kernel-model",provenance,featureIDMap,labelIDMap,kernel,supportArray,alphaMatrix);
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "KernelSVMTrainer(kernel="+kernel.toString()+",lambda="+lambda+",epochs="+epochs+",seed="+seed+")";
    }

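    /**
     * Computes the per class kernel scores for a single example, i.e.,
     * score[c] is the sum over support vectors i of alphas[c][i] * kernel(features, sv_i).
     * @param features The example's feature vector.
     * @param sv The current support vectors, keyed by example index.
     * @param alphas The per class, per example alpha values.
     * @return A dense vector containing one score per class.
     */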
    private SGDVector predict(SparseVector features, Map<Integer,SparseVector> sv, double[][] alphas) {
        double[] score = new double[alphas.length];

        for (Map.Entry<Integer, SparseVector> e : sv.entrySet()) {
            double distance = kernel.similarity(features,e.getValue());
            for (int i = 0; i < alphas.length; i++) {
                score[i] += alphas[i][e.getKey()] * distance;
            }
        }

        return DenseVector.createDenseVector(score);
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}