/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.SparseTrainer;
import org.tribuo.WeightedExamples;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.util.Util;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A trainer for a sparse linear regression model.
 * Uses sequential forward selection to construct the model. Optionally can
 * normalize the data first. Each output dimension is trained independently
 * with no shared regularization.
 */
public class SLMTrainer implements SparseTrainer<Regressor>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(SLMTrainer.class.getName());

    @Config(description="Maximum number of features to use.")
    protected int maxNumFeatures = -1;

    @Config(description="Normalize the data first.")
    protected boolean normalize;

    protected int trainInvocationCounter = 0;

    /**
     * Constructs a trainer for a sparse linear model using sequential forward selection.
     *
     * @param normalize Normalizes the data first (i.e., removes the bias term).
     * @param maxNumFeatures The maximum number of features to select. Supply -1 to select all features.
     */
    public SLMTrainer(boolean normalize, int maxNumFeatures) {
        this.normalize = normalize;
        this.maxNumFeatures = maxNumFeatures;
    }

    /**
     * Constructs a trainer for a sparse linear model using sequential forward selection.
     * <p>
     * Selects all the features.
     *
     * @param normalize Normalizes the data first (i.e., removes the bias term).
     */
    public SLMTrainer(boolean normalize) {
        this(normalize,-1);
    }

    /**
     * For OLCUT.
     */
    protected SLMTrainer() {}
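
    /*
     * Usage sketch (illustrative only; the "data" variable is assumed to be a
     * regression Dataset built elsewhere):
     *
     *   Dataset<Regressor> data = ...;                   // assumed pre-built dataset
     *   SLMTrainer trainer = new SLMTrainer(false, 10);  // select at most 10 features
     *   SparseModel<Regressor> model = trainer.train(data);
     *
     * The trainer can also be constructed through the OLCUT configuration system
     * via the @Config-annotated "normalize" and "maxNumFeatures" fields.
     */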

    protected RealVector newWeights(SLMState state) {
        RealVector result = SLMTrainer.ordinaryLeastSquares(state.xpi,state.y);

        if (result == null) {
            return null;
        } else {
            return state.unpack(result);
        }
    }

    /**
     * Trains a sparse linear model.
     * @param examples The data set containing the examples.
     * @param runProvenance Training run specific provenance information.
     * @return A trained sparse linear model.
     */
    @Override
    public SparseLinearModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }

        TrainerProvenance trainerProvenance;
        synchronized(this) {
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        Set<Regressor> domain = outputInfo.getDomain();
        int numOutputs = outputInfo.size();
        int numExamples = examples.size();
        int numFeatures = normalize ? featureIDMap.size() : featureIDMap.size() + 1; // include bias
        double[][] outputs = new double[numOutputs][numExamples];
        SparseVector[] inputs = new SparseVector[numExamples];
        int n = 0;
        for (Example<Regressor> e : examples) {
            inputs[n] = SparseVector.createSparseVector(e,featureIDMap,!normalize);
            double curWeight = Math.sqrt(e.getWeight());
            inputs[n].scaleInPlace(curWeight); // rescale features by example weight
            for (Regressor.DimensionTuple r : e.getOutput()) {
                int id = outputInfo.getID(r);
                outputs[id][n] = r.getValue() * curWeight; // rescale output by example weight
            }
            n++;
        }

        // Extract featureMatrix from the sparse vectors
        RealMatrix featureMatrix = new Array2DRowRealMatrix(numExamples, numFeatures);
        double[] denseFeatures = new double[numFeatures];
        for (int i = 0; i < inputs.length; i++) {
            Arrays.fill(denseFeatures,0.0);
            for (VectorTuple vec : inputs[i]) {
                denseFeatures[vec.index] = vec.value;
            }
            featureMatrix.setRow(i, denseFeatures);
        }

        double[] featureMeans = new double[numFeatures];
        double[] featureVariances = new double[numFeatures];
        double[] outputMeans = new double[numOutputs];
        double[] outputVariances = new double[numOutputs];
        if (normalize) {
            for (int i = 0; i < numFeatures; ++i) {
                double[] featV = featureMatrix.getColumn(i);
                featureMeans[i] = Util.mean(featV);

                for (int j=0; j < featV.length; ++j) {
                    featV[j] -= featureMeans[i];
                }

                RealVector xp = new ArrayRealVector(featV);
                featureVariances[i] = xp.getNorm();
                featureMatrix.setColumnVector(i,xp.mapDivideToSelf(featureVariances[i]));
            }

            for (int i = 0; i < numOutputs; i++) {
                outputMeans[i] = Util.mean(outputs[i]);
                // Remove mean and aggregate variance
                double sum = 0.0;
                for (int j = 0; j < numExamples; j++) {
                    outputs[i][j] -= outputMeans[i];
                    sum += outputs[i][j] * outputs[i][j];
                }
                outputVariances[i] = Math.sqrt(sum);
                // Remove variance
                for (int j = 0; j < numExamples; j++) {
                    outputs[i][j] /= outputVariances[i];
                }
            }
        } else {
            Arrays.fill(featureMeans,0.0);
            Arrays.fill(featureVariances,1.0);
            Arrays.fill(outputMeans,0.0);
            Arrays.fill(outputVariances,1.0);
        }
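
        // Note: despite their names, featureVariances and outputVariances hold the L2
        // norms of the centered columns rather than variances. When normalize is true,
        // each feature column and each output dimension now has zero mean and unit L2
        // norm; the stored means and norms are passed to the SparseLinearModel below so
        // that predictions can be rescaled back into the original output space.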

        // Construct the output matrix from the double[][] after scaling
        RealMatrix outputMatrix = new Array2DRowRealMatrix(outputs);

        // Array of example indices, used to extract submatrices of the feature matrix
        int[] exampleRows = new int[numExamples];
        for (int i = 0; i < numExamples; ++i) {
            exampleRows[i] = i;
        }

        RealVector one = new ArrayRealVector(numExamples,1.0);

        int numToSelect;
        if ((maxNumFeatures < 1) || (maxNumFeatures > featureIDMap.size())) {
            numToSelect = featureIDMap.size();
        } else {
            numToSelect = maxNumFeatures;
        }

        String[] dimensionNames = new String[numOutputs];
        SparseVector[] modelWeights = new SparseVector[numOutputs];
        for (Regressor r : domain) {
            int id = outputInfo.getID(r);
            dimensionNames[id] = r.getNames()[0];
            SLMState state = new SLMState(featureMatrix,outputMatrix.getRowVector(id),featureIDMap,normalize);
            modelWeights[id] = trainSingleDimension(state,exampleRows,numToSelect,one);
        }

        ModelProvenance provenance = new ModelProvenance(SparseLinearModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        return new SparseLinearModel("slm-model", dimensionNames, provenance, featureIDMap, outputInfo, modelWeights,
                DenseVector.createDenseVector(featureMeans), DenseVector.createDenseVector(featureVariances),
                outputMeans, outputVariances, !normalize);
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    @Override
    public String toString() {
        return "SFSTrainer(normalize="+normalize+",maxNumFeatures="+maxNumFeatures+")";
    }
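
    /*
     * Sequential forward selection, sketched in pseudocode (this mirrors the
     * loop in trainSingleDimension below):
     *
     *   active = {}; beta = 0
     *   while |active| < numToSelect:
     *       r    = y - X * beta                          // current residual
     *       corr = X^T * r                               // per-feature correlation
     *       j    = argmax over i not in active of |corr_i|
     *       active = active + {j}
     *       beta = OLS fit of y on the columns of X indexed by active
     *
     * Each step adds the feature most correlated with the current residual and
     * refits the weights on the enlarged active set. The loop stops early if the
     * Gram matrix of the active columns becomes singular.
     */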

    /**
     * Trains a single dimension.
     * @param state The state object to use.
     * @param exampleRows An array containing the row indices.
     * @param numToSelect The number of features to select.
     * @param one A RealVector of ones.
     * @return The sparse vector representing the learned feature weights.
     */
    private SparseVector trainSingleDimension(SLMState state, int[] exampleRows, int numToSelect, RealVector one) {
        int iter = 0;
        while (state.active.size() < numToSelect) {
            // Compute the residual
            state.r = state.y.subtract(state.X.operate(state.beta));

            logger.info("At iteration " + iter + " average residual " + state.r.dotProduct(one) / state.numExamples);
            iter++;
            // Compute the correlation
            state.corr = state.X.transpose().operate(state.r);

            // Identify the most correlated feature
            double max = -1;
            int feature = -1;
            for (int i = 0; i < state.numFeatures; ++i) {
                if (!state.activeSet.contains(i)) {
                    double absCorr = Math.abs(state.corr.getEntry(i));

                    if (absCorr > max) {
                        max = absCorr;
                        feature = i;
                    }
                }
            }

            state.C = max;

            state.active.add(feature);
            state.activeSet.add(feature);

            if (!state.normalize && (feature == state.numFeatures-1)) {
                logger.info("Bias selected");
            } else {
                logger.info("Feature selected: " + state.featureIDMap.get(feature).getName() + " (pos=" + feature + ")");
            }

            // Compute the active matrix
            int[] activeFeatures = Util.toPrimitiveInt(state.active);
            state.xpi = state.X.getSubMatrix(exampleRows, activeFeatures);

            if (state.active.size() == (numToSelect - 1)) {
                state.last = true;
            }

            RealVector betapi = newWeights(state);

            if (betapi == null) {
                // Matrix was not invertible
                logger.log(Level.INFO, "Stopping at feature " + state.active.size() + ", the matrix was no longer invertible.");
                break;
            }

            state.beta = betapi;
        }

        Map<Integer, Double> parameters = new HashMap<>();

        for (int i = 0; i < state.numFeatures; ++i) {
            if (state.beta.getEntry(i) != 0) {
                parameters.put(i, state.beta.getEntry(i));
            }
        }

        return SparseVector.createSparseVector(state.numFeatures, parameters);
    }

    /**
     * Minimizes the ordinary least squares objective, computing the weights
     * (M^T M)^{-1} M^T target.
     * <p>
     * Returns null if the matrix is not invertible.
     * @param M The matrix of features.
     * @param target The vector of target values.
     * @return The OLS solution for the supplied features.
     */
    static RealVector ordinaryLeastSquares(RealMatrix M, RealVector target) {
        RealMatrix inv;
        try {
            inv = new LUDecomposition(M.transpose().multiply(M)).getSolver().getInverse();
        } catch (SingularMatrixException s) {
            // The matrix is not invertible, there is nothing we can do.
            // We let the caller decide what to do.
            return null;
        }

        return inv.multiply(M.transpose()).operate(target);
    }

    /**
     * Sums all the elements of the inverted Gram matrix, i.e., computes
     * 1^T (matrix^T matrix)^{-1} 1.
     * @param matrix The matrix to operate on.
     * @return The sum of the elements of the inverted Gram matrix.
     */
    static double sumInverted(RealMatrix matrix) {
        // The potential SingularMatrixException is deliberately not caught: in the
        // context of LARS, this method is only called when the matrix is known to be invertible.
        RealMatrix inv = new LUDecomposition(matrix.transpose().multiply(matrix)).getSolver().getInverse();

        RealVector one = new ArrayRealVector(matrix.getColumnDimension(),1.0);

        return one.dotProduct(inv.operate(one));
    }
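
    /*
     * In LARS notation (Efron et al., 2004), with active feature matrix X_A and
     * Gram matrix G_A = X_A^T X_A:
     *
     *   A_A = (1^T G_A^{-1} 1)^{-1/2}
     *   w_A = A_A * G_A^{-1} 1
     *
     * sumInverted above computes the scalar 1^T G_A^{-1} 1 inside the parentheses,
     * and getwa below computes w_A given the precomputed constant A_A.
     */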

    /**
     * Computes the LARS direction weights: inverts the Gram matrix M^T M, applies
     * the inverse to a vector of ones, and scales the result by the supplied value,
     * i.e., computes AA * (M^T M)^{-1} 1.
     * @param M The matrix to invert.
     * @param AA The value to scale by.
     * @return The vector of weights.
     */
    static RealVector getwa(RealMatrix M, double AA) {
        RealMatrix inv = new LUDecomposition(M.transpose().multiply(M)).getSolver().getInverse();
        RealVector one = new ArrayRealVector(M.getColumnDimension(),1.0);

        return inv.operate(one).mapMultiply(AA);
    }

    /**
     * Calculates D^T . (M . v).
     * Used in LARS.
     * @param D A matrix.
     * @param M A matrix.
     * @param v A vector.
     * @return D^T . (M . v)
     */
    static RealVector getA(RealMatrix D, RealMatrix M, RealVector v) {
        RealVector u = M.operate(v);
        return D.transpose().operate(u);
    }

    static class SLMState {
        protected final int numExamples;
        protected final int numFeatures;
        protected final boolean normalize;
        protected final ImmutableFeatureMap featureIDMap;

        protected final Set<Integer> activeSet;
        protected final List<Integer> active;

        protected final RealMatrix X;
        protected final RealVector y;

        protected RealMatrix xpi; // submatrix of X restricted to the active features
        protected RealVector r;   // current residual
        protected RealVector beta; // current weights

        protected double C; // maximum absolute correlation
        protected RealVector corr; // per-feature correlations with the residual

        protected boolean last = false;

        public SLMState(RealMatrix features, RealVector outputs, ImmutableFeatureMap featureIDMap, boolean normalize) {
            this.numExamples = features.getRowDimension();
            this.numFeatures = features.getColumnDimension();
            this.featureIDMap = featureIDMap;
            this.normalize = normalize;
            this.active = new ArrayList<>();
            this.activeSet = new HashSet<>();
            this.beta = new ArrayRealVector(numFeatures);
            this.X = features;
            this.y = outputs;
        }

        /**
         * Unpacks the active set into a dense vector using the supplied values.
         * @param values The values.
         * @return A dense vector containing the values at the active set indices.
         */
        public RealVector unpack(RealVector values) {
            RealVector u = new ArrayRealVector(numFeatures);

            for (int i = 0; i < active.size(); ++i) {
                u.setEntry(active.get(i), values.getEntry(i));
            }

            return u;
        }
    }
}