001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.slm;
018
019import org.apache.commons.math3.linear.RealVector;
020
021import java.util.ArrayList;
022import java.util.Collections;
023import java.util.List;
024import java.util.logging.Logger;
025
026/**
027 * A trainer for a lasso linear regression model which uses LARS to construct the model.
028 * Each output dimension is trained independently.
029 * <p>
030 * See:
031 * <pre>
032 * Efron B, Hastie T, Johnstone I, Tibshirani R.
033 * "Least Angle Regression"
034 * The Annals of Statistics, 2004.
035 * </pre>
036 */
037public class LARSLassoTrainer extends SLMTrainer {
038    private static final Logger logger = Logger.getLogger(LARSLassoTrainer.class.getName());
039
040    /**
041     * Constructs a lasso LARS trainer for a linear model.
042     * @param maxNumFeatures The maximum number of features to select. Supply -1 to select all features.
043     */
044    public LARSLassoTrainer(int maxNumFeatures) {
045        super(true,maxNumFeatures);
046    }
047
048    /**
049     * Constructs a lasso LARS trainer that selects all the features.
050     */
051    public LARSLassoTrainer() {
052        this(-1);
053    }
054
055    @Override
056    protected RealVector newWeights(SLMState state) {
057        if (state.last) {
058            return super.newWeights(state);
059        }
060
061        RealVector deltapi =  SLMTrainer.ordinaryLeastSquares(state.xpi,state.r);
062
063        if (deltapi == null) {
064            return null;
065        }
066
067        RealVector delta = state.unpack(deltapi);
068
069        // Computing gamma
070        List<Double> candidates = new ArrayList<>();
071
072        double AA = SLMTrainer.sumInverted(state.xpi);
073        double CC = state.C;
074
075        RealVector wa = SLMTrainer.getwa(state.xpi,AA);
076        RealVector ar = SLMTrainer.getA(state.X, state.xpi,wa);
077
078        for (int i = 0; i < state.numFeatures; ++i) {
079            if (!state.activeSet.contains(i)) {
080                double c = state.corr.getEntry(i);
081                double a = ar.getEntry(i);
082
083                double v1 = (CC - c) / (AA - a);
084                double v2 = (CC + c) / (AA + a);
085
086                if (v1 >= 0) {
087                    candidates.add(v1);
088                }
089                if (v2 >= 0) {
090                    candidates.add(v2);
091                }
092            }
093        }
094
095        double gamma = Collections.min(candidates);
096
097//        // The lasso modification
098//        if (active.size() >= 2) {
099//            int min = active.get(0);
100//            double min_gamma = - beta.getEntry(min) / (wa.getEntry(active.indexOf(new Integer(min))) * (corr.getEntry(min) >= 0 ? +1 : -1));
101//
102//            for (int i = 1; i < active.size()-1; ++i) {
103//                int idx = active.get(i);
104//                double gamma_i = - beta.getEntry(idx) / (wa.getEntry(active.indexOf(new Integer(idx))) * (corr.getEntry(idx) >= 0 ? +1 : -1));
105//                if (gamma_i < 0) continue;
106//                if (gamma_i < min) {
107//                    min = i;
108//                    min_gamma = gamma_i;
109//                }
110//            }
111//
112//            if (min_gamma < gamma) {
113//                active.remove(new Integer(min));
114//                beta.setEntry(min,0.0);
115//                return beta.add(delta.mapMultiplyToSelf(min_gamma));
116//            }
117//        }
118//
119//        return beta.add(delta.mapMultiplyToSelf(gamma));
120
121        RealVector other = delta.mapMultiplyToSelf(gamma);
122
123        for (int i = 0; i < state.numFeatures; ++i) {
124            double betaElement = state.beta.getEntry(i);
125            double otherElement = other.getEntry(i);
126            if ((betaElement > 0 && betaElement + otherElement < 0)
127                    || (betaElement < 0 && betaElement + otherElement > 0)) {
128                state.beta.setEntry(i,0.0);
129                other.setEntry(i,0.0);
130                Integer integer = i;
131                state.active.remove(integer);
132                state.activeSet.remove(integer);
133            }
134        }
135
136        return state.beta.add(other);
137    }
138
139    @Override
140    public String toString() {
141        return "LARSLassoTrainer(maxNumFeatures="+maxNumFeatures+")";
142    }
143}