/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.Trainer;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;
import org.tribuo.util.Util;

import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;

import static org.tribuo.math.la.VectorTuple.DELTA;

/**
 * An ElasticNet trainer that uses co-ordinate descent. Modelled after scikit-learn's sparse matrix implementation.
 * Each output dimension is trained independently.
 * <p>
 * See:
 * <pre>
 * Friedman J, Hastie T, Tibshirani R.
 * "Regularization Paths for Generalized Linear Models via Coordinate Descent"
 * Journal of Statistical Software, 2010
 * </pre>
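 * <p>
 * A minimal usage sketch (assuming {@code trainData} is a previously constructed regression {@link Dataset}):
 * <pre>{@code
 * ElasticNetCDTrainer trainer = new ElasticNetCDTrainer(1.0, 0.5);
 * SparseModel<Regressor> model = trainer.train(trainData);
 * }</pre>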
 */
public class ElasticNetCDTrainer implements SparseTrainer<Regressor> {

    private static final Logger logger = Logger.getLogger(ElasticNetCDTrainer.class.getName());

    @Config(mandatory = true,description="Overall regularisation penalty.")
    private double alpha;

    @Config(mandatory = true,description="Ratio of l1 to l2 parameters.")
    private double l1Ratio;

    @Config(description="Tolerance on the error.")
    private double tolerance = 1e-4;

    @Config(description="Maximum number of iterations to run.")
    private int maxIterations = 500;

    @Config(description="Randomises the order in which the features are probed.")
    private boolean randomise = false;

    @Config(description="The seed for the RNG.")
    private long seed = Trainer.DEFAULT_SEED;

    private SplittableRandom rng;

    private int trainInvocationCounter;

    /**
     * For olcut.
     */
    private ElasticNetCDTrainer() { }

    /**
     * Constructs an elastic net trainer using the supplied parameters, a tolerance of 1e-4,
     * a maximum of 500 iterations, no randomisation and the default RNG seed.
     * @param alpha The overall regularisation penalty.
     * @param l1Ratio The ratio of the l1 penalty to the l2 penalty.
     */
    public ElasticNetCDTrainer(double alpha, double l1Ratio) {
        this(alpha,l1Ratio,1e-4,500,false,Trainer.DEFAULT_SEED);
    }

    /**
     * Constructs a randomised elastic net trainer using the supplied parameters and seed,
     * a tolerance of 1e-4 and a maximum of 500 iterations.
     * @param alpha The overall regularisation penalty.
     * @param l1Ratio The ratio of the l1 penalty to the l2 penalty.
     * @param seed The seed for the RNG.
     */
    public ElasticNetCDTrainer(double alpha, double l1Ratio, long seed) {
        this(alpha,l1Ratio,1e-4,500,true,seed);
    }

    /**
     * Constructs an elastic net trainer using the supplied parameters.
     * @param alpha The overall regularisation penalty.
     * @param l1Ratio The ratio of the l1 penalty to the l2 penalty.
     * @param tolerance The convergence tolerance on the error.
     * @param maxIterations The maximum number of iterations to run.
     * @param randomise Randomise the order in which the features are probed.
     * @param seed The seed for the RNG.
     */
    public ElasticNetCDTrainer(double alpha, double l1Ratio, double tolerance, int maxIterations, boolean randomise, long seed) {
        this.alpha = alpha;
        this.l1Ratio = l1Ratio;
        this.tolerance = tolerance;
        this.maxIterations = maxIterations;
        this.randomise = randomise;
        this.seed = seed;
        postConfig();
    }

    @Override
    public synchronized void postConfig() {
        if ((l1Ratio < DELTA) || (l1Ratio > 1.0 + DELTA)) {
            throw new PropertyException("l1Ratio","L1 ratio must be between 0 and 1. Found value " + l1Ratio);
        }
        this.rng = new SplittableRandom(seed);
    }

    @Override
    public SparseModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count, generates provenance.
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        synchronized(this) {
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
        int numFeatures = featureIDMap.size();
        int numOutputs = outputInfo.size();
        int numExamples = examples.size();
        SparseVector[] columns = SparseVector.transpose(examples,featureIDMap);
        String[] dimensionNames = new String[numOutputs];
        DenseVector[] regressionTargets = new DenseVector[numOutputs];
        for (int i = 0; i < numOutputs; i++) {
            dimensionNames[i] = outputInfo.getOutput(i).getNames()[0];
            regressionTargets[i] = new DenseVector(numExamples);
        }
        int i = 0;
        for (Example<Regressor> e : examples) {
            int j = 0;
            for (DimensionTuple d : e.getOutput()) {
                regressionTargets[j].set(i, d.getValue());
                j++;
            }
            i++;
        }
        double l1Penalty = alpha * l1Ratio * numExamples;
        double l2Penalty = alpha * (1.0 - l1Ratio) * numExamples;

        double[] featureMeans = calculateMeans(columns);
        double[] featureVariances = new double[columns.length];
        Arrays.fill(featureVariances,1.0);
        boolean center = false;
        for (i = 0; i < numFeatures; i++) {
            if (Math.abs(featureMeans[i]) > DELTA) {
                center = true;
                break;
            }
        }
        double[] columnNorms = new double[numFeatures];
        int[] featureIndices = new int[numFeatures];

        for (i = 0; i < numFeatures; i++) {
            featureIndices[i] = i;
            double variance = 0.0;
            for (VectorTuple v : columns[i]) {
                variance += (v.value - featureMeans[i]) * (v.value - featureMeans[i]);
            }
            columnNorms[i] = variance + (numExamples - columns[i].numActiveElements()) * featureMeans[i] * featureMeans[i];
        }

        ElasticNetState elState = new ElasticNetState(columns,featureIndices,featureMeans,columnNorms,l1Penalty,l2Penalty,center);

        SparseVector[] outputWeights = new SparseVector[numOutputs];
        double[] outputMeans = new double[numOutputs];
        for (int j = 0; j < dimensionNames.length; j++) {
            outputWeights[j] = trainSingleDimension(regressionTargets[j],elState,localRNG.split());
            outputMeans[j] = regressionTargets[j].sum() / numExamples;
        }
        double[] outputVariances = new double[numOutputs];//calculateVariances(regressionTargets,outputMeans);
        Arrays.fill(outputVariances,1.0);

        ModelProvenance provenance = new ModelProvenance(SparseLinearModel.class.getName(), OffsetDateTime.now(),examples.getProvenance(),trainerProvenance,runProvenance);
        return new SparseLinearModel("elastic-net-model", dimensionNames, provenance, featureIDMap, outputInfo,
                outputWeights, DenseVector.createDenseVector(featureMeans), DenseVector.createDenseVector(featureVariances),
                outputMeans, outputVariances, false);
    }

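    /**
     * Trains a single output dimension via co-ordinate descent.
     * <p>
     * Each pass over the features removes a feature's contribution from the residual,
     * then applies the soft-thresholding update
     * {@code w_j = sign(x_j . r) * max(|x_j . r| - l1Penalty, 0) / (columnNorm_j + l2Penalty)}
     * before folding the updated contribution back into the residual.
     * @param regressionTargets The regression targets for this dimension.
     * @param state The immutable elastic net state.
     * @param localRNG The RNG used to permute the feature order when {@code randomise} is set.
     * @return A sparse vector of feature weights.
     */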
    private SparseVector trainSingleDimension(DenseVector regressionTargets, ElasticNetState state, SplittableRandom localRNG) {
        int numFeatures = state.numFeatures;
        int numExamples = state.numExamples;
        DenseVector residuals = regressionTargets.copy();
        DenseVector weights = new DenseVector(numFeatures);
        double targetTwoNorm = regressionTargets.twoNorm();
        double newTolerance = tolerance * targetTwoNorm * targetTwoNorm;

        double[] xTransposeR = new double[numFeatures];
        double[] xTransposeAlpha = new double[numFeatures];

        for (int i = 0; i < maxIterations; i++) {
            double maxWeight = 0.0;
            double maxUpdate = 0.0;

            // If randomly selecting the features, permute the indices
            if (randomise) {
                Util.randpermInPlace(state.featureIndices,localRNG);
            }

            // Iterate through the features
            for (int j = 0; j < numFeatures; j++) {
                int feature = state.featureIndices[j];

                if (Math.abs(state.columnNorms[feature]) < DELTA) {
                    continue;
                }

                double oldWeight = weights.get(feature);

                // Update residual
                if (oldWeight != 0.0) {
                    for (VectorTuple v : state.columns[feature]) {
                        residuals.set(v.index, residuals.get(v.index) + (v.value * oldWeight));
                    }
                    if (state.center) {
                        for (int k = 0; k < numExamples; k++) {
                            residuals.set(k, residuals.get(k) - (state.featureMeans[feature] * oldWeight));
                        }
                    }
                }

                // Update the weights in the required direction
                double curDot = residuals.dot(state.columns[feature]);
                if (state.center) {
                    curDot -= residuals.sum() * state.featureMeans[feature];
                }
                double newWeight = Math.signum(curDot) * Math.max(Math.abs(curDot) - state.l1Penalty, 0) / (state.columnNorms[feature] + state.l2Penalty);
                weights.set(feature,newWeight);

                // Update residual after step
                if (newWeight != 0.0) {
                    for (VectorTuple v : state.columns[feature]) {
                        residuals.set(v.index, residuals.get(v.index) - (v.value * newWeight));
                    }
                    if (state.center) {
                        for (int k = 0; k < numExamples; k++) {
                            residuals.set(k, residuals.get(k) + (state.featureMeans[feature] * newWeight));
                        }
                    }
                }

                double curUpdate = Math.abs(newWeight - oldWeight);

                if (curUpdate > maxUpdate) {
                    maxUpdate = curUpdate;
                }

                double absNewWeight = Math.abs(newWeight);
                if (absNewWeight > maxWeight) {
                    maxWeight = absNewWeight;
                }
            }

            //logger.log(Level.INFO, "Iteration " + i + ", average residual = " + residuals.sum()/numExamples);

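            // The convergence test below follows the scikit-learn implementation: when the
            // largest weight update is small (or on the final iteration), the residual is
            // rescaled into a dual-feasible point and the duality gap between the primal and
            // dual objectives is computed. The gap upper bounds the suboptimality of the
            // current weights, so iteration stops once it falls below tolerance * ||y||^2.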
            // Check the termination condition
            if ((maxWeight < DELTA) || (maxUpdate / maxWeight < tolerance) || (i == (maxIterations-1))) {
                double residualSum = residuals.sum();

                double maxAbsXTA = 0.0;
                for (int j = 0; j < numFeatures; j++) {
                    xTransposeR[j] = residuals.dot(state.columns[j]);

                    if (state.center) {
                        xTransposeR[j] -= state.featureMeans[j] * residualSum;
                    }

                    xTransposeAlpha[j] = xTransposeR[j] - state.l2Penalty * weights.get(j);

                    double curAbs = Math.abs(xTransposeAlpha[j]);
                    if (curAbs > maxAbsXTA) {
                        maxAbsXTA = curAbs;
                    }
                }

                double residualTwoNorm = residuals.twoNorm();
                residualTwoNorm *= residualTwoNorm;

                double weightsTwoNorm = weights.twoNorm();
                weightsTwoNorm *= weightsTwoNorm;

                double weightsOneNorm = weights.oneNorm();

                double scalingFactor, dualityGap;
                if (maxAbsXTA > state.l1Penalty) {
                    scalingFactor = state.l1Penalty / maxAbsXTA;
                    double alphaNorm = residualTwoNorm * scalingFactor * scalingFactor;
                    dualityGap = 0.5 * (residualTwoNorm + alphaNorm);
                } else {
                    scalingFactor = 1.0;
                    dualityGap = residualTwoNorm;
                }

                dualityGap += state.l1Penalty * weightsOneNorm - scalingFactor * residuals.dot(regressionTargets);
                dualityGap += 0.5 * state.l2Penalty * (1 + (scalingFactor * scalingFactor)) * weightsTwoNorm;

                if (dualityGap < newTolerance) {
                    // All done, stop iterating.
                    logger.log(Level.INFO,"Iteration: " + i + ", duality gap = " + dualityGap + ", tolerance = " + newTolerance);
                    break;
                }
            }
        }

        return weights.sparsify();
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "ElasticNetCDTrainer(alpha="+alpha+",l1Ratio="+l1Ratio+
                ",tolerance="+tolerance+",maxIterations="+maxIterations+
                ",randomise="+randomise+",seed="+seed+")";
    }

    private static double[] calculateMeans(SGDVector[] columns) {
        double[] means = new double[columns.length];

        for (int i = 0; i < means.length; i++) {
            means[i] = columns[i].sum() / columns[i].size();
        }

        return means;
    }

    private static double[] calculateVariances(SGDVector[] columns, double[] means) {
        double[] variances = new double[columns.length];

        for (int i = 0; i < variances.length; i++) {
            variances[i] = columns[i].variance(means[i]);
        }

        return variances;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Carrier type for the immutable elastic net state.
     */
    private static class ElasticNetState {
        final SparseVector[] columns;
        final int numFeatures;
        final int numExamples;
        final int[] featureIndices;
        final double[] featureMeans;
        final double[] columnNorms;
        final double l1Penalty;
        final double l2Penalty;
        final boolean center;

        public ElasticNetState(SparseVector[] columns, int[] featureIndices, double[] featureMeans, double[] columnNorms, double l1Penalty, double l2Penalty, boolean center) {
            this.columns = columns;
            this.numFeatures = columns.length;
            this.numExamples = columns[0].size();
            this.featureIndices = featureIndices;
            this.featureMeans = featureMeans;
            this.columnNorms = columnNorms;
            this.l1Penalty = l1Penalty;
            this.l2Penalty = l2Penalty;
            this.center = center;
        }
    }
}