/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sgd.kernel;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.Util;
import org.tribuo.math.kernel.Kernel;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.time.OffsetDateTime;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * A trainer for a kernelised model using the Pegasos optimiser.
 * <p>
 * The Pegasos optimiser is extremely sensitive to the lambda parameter, and this
 * value must be tuned to get good performance.
 * <p>
 * See:
 * <pre>
 * Shalev-Shwartz S, Singer Y, Srebro N, Cotter A
 * "Pegasos: Primal Estimated Sub-Gradient Solver for SVM"
 * Mathematical Programming, 2011.
 * </pre>
 */
public class KernelSVMTrainer implements Trainer<Label>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(KernelSVMTrainer.class.getName());

    @Config(mandatory = true,description="SVM kernel.")
    private Kernel kernel;

    // NOTE(review): described below as the l2 regulariser, but this field is not
    // referenced anywhere in the training loop — confirm against the upstream
    // Pegasos implementation whether the step size should incorporate it.
    @Config(mandatory = true,description="Step size.")
    private double lambda;

    @Config(description="Number of SGD epochs.")
    private int epochs = 5;

    // -1 disables per-iteration loss logging.
    @Config(description="Log values after this many updates.")
    private int loggingInterval = -1;

    @Config(mandatory = true,description="Seed for the RNG used to shuffle elements.")
    private long seed;

    @Config(description="Shuffle the data before each epoch. Only turn off for debugging.")
    private boolean shuffle = true;

    // Seeded in postConfig; split() per train call so repeated calls differ deterministically.
    private SplittableRandom rng;

    private int trainInvocationCounter;

    /**
     * Constructs a trainer for a kernel SVM model.
     * @param kernel The kernel function to use as a similarity measure.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param lambda l2 regulariser on the support vectors.
     * @param loggingInterval Log the loss after this many iterations. If -1 don't log anything.
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public KernelSVMTrainer(Kernel kernel, double lambda, int epochs, int loggingInterval, long seed) {
        this.kernel = kernel;
        this.lambda = lambda;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.seed = seed;
        postConfig();
    }

    /**
     * Constructs a trainer for a kernel SVM model.
     * Sets the logging interval to 1000.
     * @param kernel The kernel function to use as a similarity measure.
     * @param lambda l2 regulariser on the support vectors.
     * @param epochs The number of epochs (complete passes through the training data).
     * @param seed A seed for the random number generator, used to shuffle the examples before each epoch.
     */
    public KernelSVMTrainer(Kernel kernel, double lambda, int epochs, long seed) {
        this(kernel,lambda,epochs,1000,seed);
    }

    /**
     * For olcut.
     */
    private KernelSVMTrainer() { }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Turn on or off shuffling of examples.
     * <p>
     * This isn't exposed in the constructor as it defaults to on.
     * This method should only be used for debugging.
     * @param shuffle If true shuffle the examples, if false leave them in their current order.
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    /**
     * Trains a kernel SVM using a multiclass Pegasos-style sub-gradient update.
     * <p>
     * Misclassified examples are added to the support vector set and their class
     * weight (alpha) is incremented by the example weight. The resulting model
     * holds only the examples that became support vectors, along with the dense
     * matrix of per-class alphas.
     * @param examples The training dataset. Must not contain unknown outputs.
     * @param runProvenance Provenance information for this specific training run.
     * @return A trained {@link KernelSVMModel}.
     * @throws IllegalArgumentException If the dataset contains unknown outputs.
     */
    @Override
    public KernelSVMModel train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Splits a child RNG off the shared one and snapshots provenance, then
        // bumps the invocation count. Done under the lock so concurrent train
        // calls each get an independent, reproducible RNG stream.
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        synchronized(this) {
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableOutputInfo<Label> labelIDMap = examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        // Extract the dataset into parallel arrays; indices[] tracks each
        // example's original position so alphas stay aligned after shuffling.
        SparseVector[] sgdFeatures = new SparseVector[examples.size()];
        int[] sgdLabels = new int[examples.size()];
        double[] weights = new double[examples.size()];
        int[] indices = new int[examples.size()];
        int n = 0;
        for (Example<Label> example : examples) {
            weights[n] = example.getWeight();
            sgdFeatures[n] = SparseVector.createSparseVector(example,featureIDMap,true);
            sgdLabels[n] = labelIDMap.getID(example.getOutput());
            indices[n] = n;
            n++;
        }
        logger.info(String.format("Training Kernel SVM with %d examples", n));
        logger.info(labelIDMap.toReadableString());

        double loss = 0.0;
        int iteration = 0;
        // Support vectors keyed by original example index; alphas indexed [class][original index].
        Map<Integer,SparseVector> supportVectors = new HashMap<>();
        double[][] alphas = new double[labelIDMap.size()][examples.size()];

        for (int i = 0; i < epochs; i++) {
            if (shuffle) {
                Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, indices, localRNG);
            }
            for (int j = 0; j < sgdFeatures.length; j++) {
                SGDVector pred = predict(sgdFeatures[j],supportVectors,alphas);
                // Subtract the margin from the true class score before taking the argmax.
                pred.add(sgdLabels[j],-1.0);
                int predIndex = pred.indexOfMax();

                if (sgdLabels[j] != predIndex) {
                    // Margin violation: accumulate the hinge loss and boost the
                    // true class's alpha for this example.
                    loss += (pred.get(sgdLabels[j]) - pred.get(predIndex)) * weights[j];
                    supportVectors.putIfAbsent(indices[j],sgdFeatures[j]);
                    alphas[sgdLabels[j]][indices[j]] += weights[j];
                }

                iteration++;
                if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss/loggingInterval + " with " + supportVectors.size() + " support vectors.");
                    loss = 0.0;
                }
            }
            logger.fine("Finished epoch " + i);
        }

        // Compact the alpha matrix so its columns line up with the support
        // vectors in ascending original-index order.
        DenseMatrix alphaMatrix = new DenseMatrix(alphas.length,supportVectors.size());
        for (int i = 0; i < alphas.length; i++) {
            int rowCounter = 0;
            for (int j = 0; j < sgdFeatures.length; j++) {
                if (supportVectors.containsKey(j)) {
                    alphaMatrix.set(i, rowCounter, alphas[i][j]);
                    rowCounter++;
                }
            }
        }

        // Gather the support vectors in the same ascending-index order as the columns above.
        int counter = 0;
        SparseVector[] supportArray = new SparseVector[supportVectors.size()];
        for (int i = 0; i < sgdFeatures.length; i++) {
            SparseVector value = supportVectors.get(i);
            if (value != null) {
                supportArray[counter] = value;
                counter++;
            }
        }

        ModelProvenance provenance = new ModelProvenance(KernelSVMModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        KernelSVMModel model = new KernelSVMModel("kernel-model",provenance,featureIDMap,labelIDMap,kernel,supportArray,alphaMatrix);
        return model;
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "KernelSVMTrainer(kernel="+kernel.toString()+",lambda="+lambda+",epochs="+epochs+",seed="+seed+")";
    }

    /**
     * Computes the per-class scores for an example as the kernel-weighted sum
     * of alphas over the current support vector set.
     * @param features The example to score.
     * @param sv The current support vectors, keyed by original example index.
     * @param alphas Per-class alpha values, indexed [class][original example index].
     * @return A dense vector of one score per class.
     */
    private SGDVector predict(SparseVector features, Map<Integer,SparseVector> sv, double[][] alphas) {
        double[] score = new double[alphas.length];

        for (Map.Entry<Integer, SparseVector> e : sv.entrySet()) {
            double distance = kernel.similarity(features,e.getValue());
            for (int i = 0; i < alphas.length; i++) {
                score[i] += alphas[i][e.getKey()] * distance;
            }
        }

        return DenseVector.createDenseVector(score);
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}