/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.Trainer;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;
import org.tribuo.util.Util;

import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;

import static org.tribuo.math.la.VectorTuple.DELTA;

/**
 * An ElasticNet trainer that uses co-ordinate descent. Modelled after scikit-learn's sparse matrix implementation.
 * Each output dimension is trained independently.
 * <p>
 * See:
 * <pre>
 * Friedman J, Hastie T, Tibshirani R.
 * "Regularization Paths for Generalized Linear Models via Coordinate Descent"
 * Journal of Statistical Software, 2010
 * </pre>
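 * <p>
 * A minimal usage sketch, assuming {@code dataset} is a pre-built {@code Dataset<Regressor>}:
 * <pre>
 *     ElasticNetCDTrainer trainer = new ElasticNetCDTrainer(1.0, 0.5);
 *     SparseModel&lt;Regressor&gt; model = trainer.train(dataset);
 * </pre>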
 */
public class ElasticNetCDTrainer implements SparseTrainer<Regressor> {

    private static final Logger logger = Logger.getLogger(ElasticNetCDTrainer.class.getName());

    @Config(mandatory = true,description="Overall regularisation penalty.")
    private double alpha;

    @Config(mandatory = true,description="Ratio of l1 to l2 parameters.")
    private double l1Ratio;

    @Config(description="Tolerance on the error.")
    private double tolerance = 1e-4;

    @Config(description="Maximum number of iterations to run.")
    private int maxIterations = 500;

    @Config(description="Randomises the order in which the features are probed.")
    private boolean randomise = false;

    @Config(description="The seed for the RNG.")
    private long seed = Trainer.DEFAULT_SEED;

    private SplittableRandom rng;

    private int trainInvocationCounter;

    /**
     * For olcut.
     */
    private ElasticNetCDTrainer() { }

    /**
     * Constructs an elastic net trainer with a tolerance of 1e-4, 500 iterations,
     * no randomisation and the default seed.
     * @param alpha The overall regularisation penalty.
     * @param l1Ratio The ratio of the l1 to l2 penalties.
     */
    public ElasticNetCDTrainer(double alpha, double l1Ratio) {
        this(alpha,l1Ratio,1e-4,500,false,Trainer.DEFAULT_SEED);
    }

    /**
     * Constructs an elastic net trainer with a tolerance of 1e-4, 500 iterations,
     * and a randomised feature order driven by the supplied seed.
     * @param alpha The overall regularisation penalty.
     * @param l1Ratio The ratio of the l1 to l2 penalties.
     * @param seed The seed for the RNG.
     */
    public ElasticNetCDTrainer(double alpha, double l1Ratio, long seed) {
        this(alpha,l1Ratio,1e-4,500,true,seed);
    }

    /**
     * Constructs an elastic net trainer using co-ordinate descent.
     * @param alpha The overall regularisation penalty.
     * @param l1Ratio The ratio of the l1 to l2 penalties.
     * @param tolerance The convergence tolerance on the error.
     * @param maxIterations The maximum number of iterations.
     * @param randomise If true, probe the features in a random order each iteration.
     * @param seed The seed for the RNG.
     */
    public ElasticNetCDTrainer(double alpha, double l1Ratio, double tolerance, int maxIterations, boolean randomise, long seed) {
        this.alpha = alpha;
        this.l1Ratio = l1Ratio;
        this.tolerance = tolerance;
        this.maxIterations = maxIterations;
        this.randomise = randomise;
        this.seed = seed;
        postConfig();
    }

    @Override
    public synchronized void postConfig() {
        if ((l1Ratio < DELTA) || (l1Ratio > 1.0 + DELTA)) {
            throw new PropertyException("l1Ratio","L1 Ratio must be between 0 and 1. Found value " + l1Ratio);
        }
        this.rng = new SplittableRandom(seed);
    }

    @Override
    public SparseModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count, generates provenance.
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        synchronized(this) {
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
        int numFeatures = featureIDMap.size();
        int numOutputs = outputInfo.size();
        int numExamples = examples.size();
        SparseVector[] columns = SparseVector.transpose(examples,featureIDMap);
        String[] dimensionNames = new String[numOutputs];
        DenseVector[] regressionTargets = new DenseVector[numOutputs];
        for (int i = 0; i < numOutputs; i++) {
            dimensionNames[i] = outputInfo.getOutput(i).getNames()[0];
            regressionTargets[i] = new DenseVector(numExamples);
        }
        int i = 0;
        for (Example<Regressor> e : examples) {
            int j = 0;
            for (DimensionTuple d : e.getOutput()) {
                regressionTargets[j].set(i, d.getValue());
                j++;
            }
            i++;
        }
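        // As in the scikit-learn coordinate descent implementation, the penalties are
        // scaled by the number of examples because the squared error minimised below is
        // an unnormalised sum over examples.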
        double l1Penalty = alpha * l1Ratio * numExamples;
        double l2Penalty = alpha * (1.0 - l1Ratio) * numExamples;

        double[] featureMeans = calculateMeans(columns);
        double[] featureVariances = new double[columns.length];
        Arrays.fill(featureVariances,1.0);
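        // Only center the features if at least one column has a non-zero mean.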
        boolean center = false;
        for (i = 0; i < numFeatures; i++) {
            if (Math.abs(featureMeans[i]) > DELTA) {
                center = true;
                break;
            }
        }
        double[] columnNorms = new double[numFeatures];
        int[] featureIndices = new int[numFeatures];

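        // Compute each column's centered squared norm from its active elements, then
        // correct for the implicit zeros ((numExamples - nnz) * mean^2) so the result
        // matches the dense computation without materialising the zero entries.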
        for (i = 0; i < numFeatures; i++) {
            featureIndices[i] = i;
            double variance = 0.0;
            for (VectorTuple v : columns[i]) {
                variance += (v.value - featureMeans[i]) * (v.value - featureMeans[i]);
            }
            columnNorms[i] = variance + (numExamples - columns[i].numActiveElements()) * featureMeans[i] * featureMeans[i];
        }

        ElasticNetState elState = new ElasticNetState(columns,featureIndices,featureMeans,columnNorms,l1Penalty,l2Penalty,center);

        SparseVector[] outputWeights = new SparseVector[numOutputs];
        double[] outputMeans = new double[numOutputs];
        for (int j = 0; j < dimensionNames.length; j++) {
            outputWeights[j] = trainSingleDimension(regressionTargets[j],elState,localRNG.split());
            outputMeans[j] = regressionTargets[j].sum() / numExamples;
        }
        // Output standardisation is disabled, so the variances are fixed at 1.0
        // rather than computed via calculateVariances.
        double[] outputVariances = new double[numOutputs];
        Arrays.fill(outputVariances,1.0);

        ModelProvenance provenance = new ModelProvenance(SparseLinearModel.class.getName(), OffsetDateTime.now(),examples.getProvenance(),trainerProvenance,runProvenance);
        return new SparseLinearModel("elastic-net-model", dimensionNames, provenance, featureIDMap, outputInfo,
                outputWeights, DenseVector.createDenseVector(featureMeans), DenseVector.createDenseVector(featureVariances),
                outputMeans, outputVariances, false);
    }

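    /**
     * Trains a single output dimension using co-ordinate descent, terminating early
     * when the duality gap falls below the scaled tolerance.
     * @param regressionTargets The regression targets for this dimension.
     * @param state The shared immutable training state.
     * @param localRNG The RNG used to permute the feature order each iteration.
     * @return The learned weights as a sparse vector.
     */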
    private SparseVector trainSingleDimension(DenseVector regressionTargets, ElasticNetState state, SplittableRandom localRNG) {
        int numFeatures = state.numFeatures;
        int numExamples = state.numExamples;
        DenseVector residuals = regressionTargets.copy();
        DenseVector weights = new DenseVector(numFeatures);
        double targetTwoNorm = regressionTargets.twoNorm();
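        // Scale the tolerance by ||y||^2 so the duality gap check is invariant to the
        // magnitude of the targets.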
        double newTolerance = tolerance * targetTwoNorm * targetTwoNorm;

        double[] xTransposeR = new double[numFeatures];
        double[] xTransposeAlpha = new double[numFeatures];

        for (int i = 0; i < maxIterations; i++) {
            double maxWeight = 0.0;
            double maxUpdate = 0.0;

            // If randomly selecting the features, permute the indices
            if (randomise) {
                Util.randpermInPlace(state.featureIndices,localRNG);
            }

            // Iterate through the features
            for (int j = 0; j < numFeatures; j++) {
                int feature = state.featureIndices[j];

                if (Math.abs(state.columnNorms[feature]) < DELTA) {
                    continue;
                }

                double oldWeight = weights.get(feature);

                // Add this feature's current contribution back into the residual
                if (oldWeight != 0.0) {
                    for (VectorTuple v : state.columns[feature]) {
                        residuals.set(v.index, residuals.get(v.index) + (v.value * oldWeight));
                    }
                    if (state.center) {
                        for (int k = 0; k < numExamples; k++) {
                            residuals.set(k, residuals.get(k) - (state.featureMeans[feature] * oldWeight));
                        }
                    }
                }

                // Compute the correlation between this feature and the current residual
                double curDot = residuals.dot(state.columns[feature]);
                if (state.center) {
                    curDot -= residuals.sum() * state.featureMeans[feature];
                }
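                // Soft-thresholding update: shrink the correlation by the l1 penalty,
                // then rescale by the l2-regularised column norm.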
                double newWeight = Math.signum(curDot) * Math.max(Math.abs(curDot) - state.l1Penalty, 0) / (state.columnNorms[feature] + state.l2Penalty);
                weights.set(feature,newWeight);

                // Subtract the updated contribution from the residual
                if (newWeight != 0.0) {
                    for (VectorTuple v : state.columns[feature]) {
                        residuals.set(v.index, residuals.get(v.index) - (v.value * newWeight));
                    }
                    if (state.center) {
                        for (int k = 0; k < numExamples; k++) {
                            residuals.set(k, residuals.get(k) + (state.featureMeans[feature] * newWeight));
                        }
                    }
                }

                double curUpdate = Math.abs(newWeight - oldWeight);

                if (curUpdate > maxUpdate) {
                    maxUpdate = curUpdate;
                }

                double absNewWeight = Math.abs(newWeight);
                if (absNewWeight > maxWeight) {
                    maxWeight = absNewWeight;
                }
            }

            //logger.log(Level.INFO, "Iteration " + i + ", average residual = " + residuals.sum()/numExamples);

            // Check the termination condition
            if ((maxWeight < DELTA) || (maxUpdate / maxWeight < tolerance) || (i == (maxIterations-1))) {
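                // Compute the duality gap to confirm convergence, rescaling the residual
                // into the dual feasible region when the correlation exceeds the l1 penalty.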
                double residualSum = residuals.sum();

                double maxAbsXTA = 0.0;
                for (int j = 0; j < numFeatures; j++) {
                    xTransposeR[j] = residuals.dot(state.columns[j]);

                    if (state.center) {
                        xTransposeR[j] -= state.featureMeans[j] * residualSum;
                    }

                    xTransposeAlpha[j] = xTransposeR[j] - state.l2Penalty * weights.get(j);

                    double curAbs = Math.abs(xTransposeAlpha[j]);
                    if (curAbs > maxAbsXTA) {
                        maxAbsXTA = curAbs;
                    }
                }

                double residualTwoNorm = residuals.twoNorm();
                residualTwoNorm *= residualTwoNorm;

                double weightsTwoNorm = weights.twoNorm();
                weightsTwoNorm *= weightsTwoNorm;

                double weightsOneNorm = weights.oneNorm();

                double scalingFactor, dualityGap;
                if (maxAbsXTA > state.l1Penalty) {
                    scalingFactor = state.l1Penalty / maxAbsXTA;
                    double alphaNorm = residualTwoNorm * scalingFactor * scalingFactor;
                    dualityGap = 0.5 * (residualTwoNorm + alphaNorm);
                } else {
                    scalingFactor = 1.0;
                    dualityGap = residualTwoNorm;
                }

                dualityGap += state.l1Penalty * weightsOneNorm - scalingFactor * residuals.dot(regressionTargets);
                dualityGap += 0.5 * state.l2Penalty * (1 + (scalingFactor * scalingFactor)) * weightsTwoNorm;

                if (dualityGap < newTolerance) {
                    // All done, stop iterating.
                    logger.log(Level.INFO,"Iteration: " + i + ", duality gap = " + dualityGap + ", tolerance = " + newTolerance);
                    break;
                }
            }
        }

        return weights.sparsify();
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "ElasticNetCDTrainer(alpha="+alpha+",l1Ratio="+l1Ratio+
                ",tolerance="+tolerance+",maxIterations="+maxIterations+
                ",randomise="+randomise+",seed="+seed+")";
    }

    private static double[] calculateMeans(SGDVector[] columns) {
        double[] means = new double[columns.length];

        for (int i = 0; i < means.length; i++) {
            means[i] = columns[i].sum() / columns[i].size();
        }

        return means;
    }

    private static double[] calculateVariances(SGDVector[] columns, double[] means) {
        double[] variances = new double[columns.length];

        for (int i = 0; i < variances.length; i++) {
            variances[i] = columns[i].variance(means[i]);
        }

        return variances;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Carrier type for the immutable elastic net state.
     */
    private static class ElasticNetState {
        final SparseVector[] columns;
        final int numFeatures;
        final int numExamples;
        final int[] featureIndices;
        final double[] featureMeans;
        final double[] columnNorms;
        final double l1Penalty;
        final double l2Penalty;
        final boolean center;

        public ElasticNetState(SparseVector[] columns, int[] featureIndices, double[] featureMeans, double[] columnNorms, double l1Penalty, double l2Penalty, boolean center) {
            this.columns = columns;
            this.numFeatures = columns.length;
            this.numExamples = columns[0].size();
            this.featureIndices = featureIndices;
            this.featureMeans = featureMeans;
            this.columnNorms = columnNorms;
            this.l1Penalty = l1Penalty;
            this.l2Penalty = l2Penalty;
            this.center = center;
        }
    }
}