001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.slm;
018
019import org.apache.commons.math3.linear.RealVector;
020
021import java.util.ArrayList;
022import java.util.Collections;
023import java.util.List;
024import java.util.logging.Logger;
025
026/**
027 * A trainer for a linear regression model which uses least angle regression.
028 * Each output dimension is trained independently.
029 * <p>
030 * See:
031 * <pre>
032 * Efron B, Hastie T, Johnstone I, Tibshirani R.
033 * "Least Angle Regression"
034 * The Annals of Statistics, 2004.
035 * </pre>
036 */
037public class LARSTrainer extends SLMTrainer {
038    private static final Logger logger = Logger.getLogger(LARSTrainer.class.getName());
039
040    /**
041     * Constructs a least angle regression trainer for a linear model.
042     * @param maxNumFeatures The maximum number of features to select. Supply -1 to select all features.
043     */
044    public LARSTrainer(int maxNumFeatures) {
045        super(true,maxNumFeatures);
046    }
047
048    /**
049     * Constructs a least angle regression trainer that selects all the features.
050     */
051    public LARSTrainer() {
052        this(-1);
053    }
054
055    @Override
056    protected RealVector newWeights(SLMState state) {
057        if (state.last) {
058            return super.newWeights(state);
059        }
060
061        RealVector deltapi =  SLMTrainer.ordinaryLeastSquares(state.xpi,state.r);
062
063        if (deltapi == null) {
064            return null;
065        }
066
067        RealVector delta = state.unpack(deltapi);
068
069        // Computing gamma
070        List<Double> candidates = new ArrayList<>();
071
072        double AA = SLMTrainer.sumInverted(state.xpi);
073        double CC = state.C;
074
075        RealVector wa = SLMTrainer.getwa(state.xpi,AA);
076        RealVector ar = SLMTrainer.getA(state.X, state.xpi,wa);
077
078        for (int i = 0; i < state.numFeatures; ++i) {
079            if (!state.activeSet.contains(i)) {
080                double c = state.corr.getEntry(i);
081                double a = ar.getEntry(i);
082
083                double v1 = (CC - c) / (AA - a);
084                double v2 = (CC + c) / (AA + a);
085
086                if (v1 >= 0) {
087                    candidates.add(v1);
088                }
089                if (v2 >= 0) {
090                    candidates.add(v2);
091                }
092            }
093        }
094
095        double gamma = Collections.min(candidates);
096
097        return state.beta.add(delta.mapMultiplyToSelf(gamma));
098    }
099
100    @Override
101    public String toString() {
102        return "LARSTrainer(maxNumFeatures="+maxNumFeatures+")";
103    }
104}