001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.slm;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Dataset;
022import org.tribuo.Example;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.SparseTrainer;
026import org.tribuo.WeightedExamples;
027import org.tribuo.math.la.DenseVector;
028import org.tribuo.math.la.SparseVector;
029import org.tribuo.math.la.VectorTuple;
030import org.tribuo.provenance.ModelProvenance;
031import org.tribuo.provenance.TrainerProvenance;
032import org.tribuo.provenance.impl.TrainerProvenanceImpl;
033import org.tribuo.regression.Regressor;
034import org.tribuo.util.Util;
035import org.apache.commons.math3.linear.Array2DRowRealMatrix;
036import org.apache.commons.math3.linear.ArrayRealVector;
037import org.apache.commons.math3.linear.LUDecomposition;
038import org.apache.commons.math3.linear.RealMatrix;
039import org.apache.commons.math3.linear.RealVector;
040import org.apache.commons.math3.linear.SingularMatrixException;
041
042import java.time.OffsetDateTime;
043import java.util.ArrayList;
044import java.util.Arrays;
045import java.util.HashMap;
046import java.util.HashSet;
047import java.util.List;
048import java.util.Map;
049import java.util.Set;
050import java.util.logging.Level;
051import java.util.logging.Logger;
052
053/**
054 * A trainer for a sparse linear regression model.
055 * Uses sequential forward selection to construct the model. Optionally can
056 * normalize the data first. Each output dimension is trained independently
057 * with no shared regularization.
058 */
059public class SLMTrainer implements SparseTrainer<Regressor>, WeightedExamples {
060    private static final Logger logger = Logger.getLogger(SLMTrainer.class.getName());
061
062    @Config(description="Maximum number of features to use.")
063    protected int maxNumFeatures = -1;
064
065    @Config(description="Normalize the data first.")
066    protected boolean normalize;
067
068    protected int trainInvocationCounter = 0;
069
070    /**
071     * Constructs a trainer for a sparse linear model using sequential forward selection.
072     *
073     * @param normalize Normalizes the data first (i.e., removes the bias term).
074     * @param maxNumFeatures The maximum number of features to select. Supply -1 to select all features.
075     */
076    public SLMTrainer(boolean normalize, int maxNumFeatures) {
077        this.normalize = normalize;
078        this.maxNumFeatures = maxNumFeatures;
079    }
080
081    /**
082     * Constructs a trainer for a sparse linear model using sequential forward selection.
083     * <p>
084     * Selects all the features.
085     *
086     * @param normalize Normalizes the data first (i.e., removes the bias term).
087     */
088    public SLMTrainer(boolean normalize) {
089        this(normalize,-1);
090    }
091
092    /**
093     * For OLCUT.
094     */
095    protected SLMTrainer() {}
096
097    protected RealVector newWeights(SLMState state) {
098        RealVector result = SLMTrainer.ordinaryLeastSquares(state.xpi,state.y);
099
100        if (result == null) {
101            return null;
102        } else {
103            return state.unpack(result);
104        }
105    }
106
107    /**
108     * Trains a sparse linear model.
109     * @param examples The data set containing the examples.
110     * @return A trained sparse linear model.
111     */
112    @Override
113    public SparseLinearModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
114        if (examples.getOutputInfo().getUnknownCount() > 0) {
115            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
116        }
117
118        TrainerProvenance trainerProvenance;
119        synchronized(this) {
120            trainerProvenance = getProvenance();
121            trainInvocationCounter++;
122        }
123        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
124        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
125        Set<Regressor> domain = outputInfo.getDomain();
126        int numOutputs = outputInfo.size();
127        int numExamples = examples.size();
128        int numFeatures = normalize ? featureIDMap.size() : featureIDMap.size() + 1; //include bias
129        double[][] outputs = new double[numOutputs][numExamples];
130        SparseVector[] inputs = new SparseVector[numExamples];
131        int n = 0;
132        for (Example<Regressor> e : examples) {
133            inputs[n] = SparseVector.createSparseVector(e,featureIDMap,!normalize);
134            double curWeight = Math.sqrt(e.getWeight());
135            inputs[n].scaleInPlace(curWeight); //rescale features by example weight
136            for (Regressor.DimensionTuple r : e.getOutput()) {
137                int id = outputInfo.getID(r);
138                outputs[id][n] = r.getValue() * curWeight; //rescale output by example weight
139            }
140            n++;
141        }
142
143        // Extract featureMatrix from the sparse vectors
144        RealMatrix featureMatrix = new Array2DRowRealMatrix(numExamples, numFeatures);
145        double[] denseFeatures = new double[numFeatures];
146        for (int i = 0; i < inputs.length; i++) {
147            Arrays.fill(denseFeatures,0.0);
148            for (VectorTuple vec : inputs[i]) {
149                denseFeatures[vec.index] = vec.value;
150            }
151            featureMatrix.setRow(i, denseFeatures);
152        }
153
154        double[] featureMeans = new double[numFeatures];
155        double[] featureVariances = new double[numFeatures];
156        double[] outputMeans = new double[numOutputs];
157        double[] outputVariances = new double[numOutputs];
158        if (normalize) {
159            for (int i = 0; i < numFeatures; ++i) {
160                double[] featV = featureMatrix.getColumn(i);
161                featureMeans[i] = Util.mean(featV);
162
163                for (int j=0; j < featV.length; ++j) {
164                    featV[j] -= featureMeans[i];
165                }
166
167                RealVector xp = new ArrayRealVector(featV);
168                featureVariances[i] = xp.getNorm();
169                featureMatrix.setColumnVector(i,xp.mapDivideToSelf(featureVariances[i]));
170            }
171
172            for (int i = 0; i < numOutputs; i++) {
173                outputMeans[i] = Util.mean(outputs[i]);
174                // Remove mean and aggregate variance
175                double sum = 0.0;
176                for (int j = 0; j < numExamples; j++) {
177                    outputs[i][j] -= outputMeans[i];
178                    sum += outputs[i][j] * outputs[i][j];
179                }
180                outputVariances[i] = Math.sqrt(sum);
181                // Remove variance
182                for (int j = 0; j < numExamples; j++) {
183                    outputs[i][j] /= outputVariances[i];
184                }
185            }
186        } else {
187            Arrays.fill(featureMeans,0.0);
188            Arrays.fill(featureVariances,1.0);
189            Arrays.fill(outputMeans,0.0);
190            Arrays.fill(outputVariances,1.0);
191        }
192
193        // Construct the output matrix from the double[][] after scaling
194        RealMatrix outputMatrix = new Array2DRowRealMatrix(outputs);
195
196        // Array example is useful to compute a submatrix
197        int[] exampleRows = new int[numExamples];
198        for (int i = 0; i < numExamples; ++i) {
199            exampleRows[i] = i;
200        }
201
202        RealVector one = new ArrayRealVector(numExamples,1.0);
203
204        int numToSelect;
205        if ((maxNumFeatures < 1) || (maxNumFeatures > featureIDMap.size())) {
206            numToSelect = featureIDMap.size();
207        } else {
208            numToSelect = maxNumFeatures;
209        }
210
211        String[] dimensionNames = new String[numOutputs];
212        SparseVector[] modelWeights = new SparseVector[numOutputs];
213        for (Regressor r : domain) {
214            int id = outputInfo.getID(r);
215            dimensionNames[id] = r.getNames()[0];
216            SLMState state = new SLMState(featureMatrix,outputMatrix.getRowVector(id),featureIDMap,normalize);
217            modelWeights[id] = trainSingleDimension(state,exampleRows,numToSelect,one);
218        }
219
220        ModelProvenance provenance = new ModelProvenance(SparseLinearModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
221        return new SparseLinearModel("slm-model", dimensionNames, provenance, featureIDMap, outputInfo, modelWeights,
222                DenseVector.createDenseVector(featureMeans), DenseVector.createDenseVector(featureVariances),
223                outputMeans, outputVariances, !normalize);
224    }
225
226    @Override
227    public int getInvocationCount() {
228        return trainInvocationCounter;
229    }
230
231    @Override
232    public TrainerProvenance getProvenance() {
233        return new TrainerProvenanceImpl(this);
234    }
235
236    @Override
237    public String toString() {
238        return "SFSTrainer(normalize="+normalize+",maxNumFeatures="+maxNumFeatures+")";
239    }
240
241    /**
242     * Trains a single dimension.
243     * @param state The state object to use.
244     * @param exampleRows An array with the row indices in.
245     * @param numToSelect The number of features to select.
246     * @param one A RealVector of ones.
247     * @return The sparse vector representing the learned feature weights.
248     */
249    private SparseVector trainSingleDimension(SLMState state, int[] exampleRows, int numToSelect, RealVector one) {
250        int iter = 0;
251        while (state.active.size() < numToSelect) {
252            // Compute the residual
253            state.r = state.y.subtract(state.X.operate(state.beta));
254
255            logger.info("At iteration " + iter + " Average residual " + state.r.dotProduct(one) / state.numExamples);
256            iter++;
257            // Compute the correlation
258            state.corr = state.X.transpose().operate(state.r);
259
260            // Identify most correlated feature
261            double max = -1;
262            int feature = -1;
263            for (int i = 0; i < state.numFeatures; ++i) {
264                if (!state.activeSet.contains(i)) {
265                    double absCorr = Math.abs(state.corr.getEntry(i));
266
267                    if (absCorr > max) {
268                        max = absCorr;
269                        feature = i;
270                    }
271                }
272            }
273
274            state.C = max;
275
276            state.active.add(feature);
277            state.activeSet.add(feature);
278
279            if (!state.normalize && (feature == state.numFeatures-1)) {
280                logger.info("Bias selected");
281            } else {
282                logger.info("Feature selected: " + state.featureIDMap.get(feature).getName() + " (pos=" + feature + ")");
283            }
284
285            // Compute the active matrix
286            int[] activeFeatures = Util.toPrimitiveInt(state.active);
287            state.xpi = state.X.getSubMatrix(exampleRows, activeFeatures);
288
289            if (state.active.size() == (numToSelect - 1)) {
290                state.last = true;
291            }
292
293            RealVector betapi = newWeights(state);
294
295            if (betapi == null) {
296                // Matrix was not invertible
297                logger.log(Level.INFO, "Stopping at feature " + state.active.size() + " matrix was no longer invertible.");
298                break;
299            }
300
301            state.beta = betapi;
302        }
303
304        Map<Integer, Double> parameters = new HashMap<>();
305
306        for (int i = 0; i < state.numFeatures; ++i) {
307            if (state.beta.getEntry(i) != 0) {
308                parameters.put(i, state.beta.getEntry(i));
309            }
310        }
311
312        return SparseVector.createSparseVector(state.numFeatures, parameters);
313    }
314
315    /**
316     * Minimize ordinary least squares.
317     *
318     * Returns null if the matrix is not invertible.
319     * @param M The matrix of features.
320     * @param target The vector of target values.
321     * @return The OLS solution for the supplied features.
322     */
323    static RealVector ordinaryLeastSquares(RealMatrix M, RealVector target) {
324        RealMatrix inv;
325        try {
326            inv = new LUDecomposition(M.transpose().multiply(M)).getSolver().getInverse();
327        } catch (SingularMatrixException s) {
328            // Matrix is not invertible, there is nothing we can do
329            // We will let the caller decide what to do
330            return null;
331        }
332
333        return inv.multiply(M.transpose()).operate(target);
334    }
335
336    /**
337     * Sums inverted matrix.
338     * @param matrix The Matrix to operate on.
339     * @return The sum of the inverted matrix.
340     */
341    static double sumInverted(RealMatrix matrix) {
342        // Why are we not trying to catch the potential exception?
343        // Because in the context of LARS, if we call this method, we know the matrix is invertible
344        RealMatrix inv = new LUDecomposition(matrix.transpose().multiply(matrix)).getSolver().getInverse();
345
346        RealVector one = new ArrayRealVector(matrix.getColumnDimension(),1.0);
347
348        return one.dotProduct(inv.operate(one));
349    }
350
351    /**
352     * Inverts the matrix, takes the dot product and scales it by the supplied value.
353     * @param M The matrix to invert.
354     * @param AA The value to scale by.
355     * @return The vector of feature values.
356     */
357    static RealVector getwa(RealMatrix M, double AA) {
358        RealMatrix inv = new LUDecomposition(M.transpose().multiply(M)).getSolver().getInverse();
359        RealVector one = new ArrayRealVector(M.getColumnDimension(),1.0);
360
361        return inv.operate(one).mapMultiply(AA);
362    }
363
364    /**
365     * Calculates (M . v) . D^T
366     * Used in LARS.
367     * @param D A matrix.
368     * @param M A matrix.
369     * @param v A vector.
370     * @return (M . v) . D^T
371     */
372    static RealVector getA(RealMatrix D, RealMatrix M, RealVector v) {
373        RealVector u = M.operate(v);
374        return D.transpose().operate(u);
375    }
376
377    static class SLMState {
378        protected final int numExamples;
379        protected final int numFeatures;
380        protected final boolean normalize;
381        protected final ImmutableFeatureMap featureIDMap;
382
383        protected final Set<Integer> activeSet;
384        protected final List<Integer> active;
385
386        protected final RealMatrix X;
387        protected final RealVector y;
388
389        protected RealMatrix xpi;
390        protected RealVector r;
391        protected RealVector beta;
392
393        protected double C;
394        protected RealVector corr;
395
396        protected Boolean last = false;
397
398        public SLMState(RealMatrix features, RealVector outputs, ImmutableFeatureMap featureIDMap, boolean normalize) {
399            this.numExamples = features.getRowDimension();
400            this.numFeatures = features.getColumnDimension();
401            this.featureIDMap = featureIDMap;
402            this.normalize = normalize;
403            this.active = new ArrayList<>();
404            this.activeSet = new HashSet<>();
405            this.beta = new ArrayRealVector(numFeatures);
406            this.X = features;
407            this.y = outputs;
408        }
409
410        /**
411         * Unpacks the active set into a dense vector using the values in values
412         * @param values The values.
413         * @return A dense vector representing the values at the active set indices.
414         */
415        public RealVector unpack(RealVector values) {
416            RealVector u = new ArrayRealVector(numFeatures);
417
418            for (int i = 0; i < active.size(); ++i) {
419                u.setEntry(active.get(i), values.getEntry(i));
420            }
421
422            return u;
423        }
424    }
425}