001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.slm; 018 019import org.apache.commons.math3.linear.RealVector; 020 021import java.util.ArrayList; 022import java.util.Collections; 023import java.util.List; 024import java.util.logging.Logger; 025 026/** 027 * A trainer for a lasso linear regression model which uses LARS to construct the model. 028 * Each output dimension is trained independently. 029 * <p> 030 * See: 031 * <pre> 032 * Efron B, Hastie T, Johnstone I, Tibshirani R. 033 * "Least Angle Regression" 034 * The Annals of Statistics, 2004. 035 * </pre> 036 */ 037public class LARSLassoTrainer extends SLMTrainer { 038 private static final Logger logger = Logger.getLogger(LARSLassoTrainer.class.getName()); 039 040 /** 041 * Constructs a lasso LARS trainer for a linear model. 042 * @param maxNumFeatures The maximum number of features to select. Supply -1 to select all features. 043 */ 044 public LARSLassoTrainer(int maxNumFeatures) { 045 super(true,maxNumFeatures); 046 } 047 048 /** 049 * Constructs a lasso LARS trainer that selects all the features. 050 */ 051 public LARSLassoTrainer() { 052 this(-1); 053 } 054 055 @Override 056 protected RealVector newWeights(SLMState state) { 057 if (state.last) { 058 return super.newWeights(state); 059 } 060 061 RealVector deltapi = SLMTrainer.ordinaryLeastSquares(state.xpi,state.r); 062 063 if (deltapi == null) { 064 return null; 065 } 066 067 RealVector delta = state.unpack(deltapi); 068 069 // Computing gamma 070 List<Double> candidates = new ArrayList<>(); 071 072 double AA = SLMTrainer.sumInverted(state.xpi); 073 double CC = state.C; 074 075 RealVector wa = SLMTrainer.getwa(state.xpi,AA); 076 RealVector ar = SLMTrainer.getA(state.X, state.xpi,wa); 077 078 for (int i = 0; i < state.numFeatures; ++i) { 079 if (!state.activeSet.contains(i)) { 080 double c = state.corr.getEntry(i); 081 double a = ar.getEntry(i); 082 083 double v1 = (CC - c) / (AA - a); 084 double v2 = (CC + c) / (AA + a); 085 086 if (v1 >= 0) { 087 candidates.add(v1); 088 } 089 if (v2 >= 0) { 090 candidates.add(v2); 091 } 092 } 093 } 094 095 double gamma = Collections.min(candidates); 096 097// // The lasso modification 098// if (active.size() >= 2) { 099// int min = active.get(0); 100// double min_gamma = - beta.getEntry(min) / (wa.getEntry(active.indexOf(new Integer(min))) * (corr.getEntry(min) >= 0 ? +1 : -1)); 101// 102// for (int i = 1; i < active.size()-1; ++i) { 103// int idx = active.get(i); 104// double gamma_i = - beta.getEntry(idx) / (wa.getEntry(active.indexOf(new Integer(idx))) * (corr.getEntry(idx) >= 0 ? +1 : -1)); 105// if (gamma_i < 0) continue; 106// if (gamma_i < min) { 107// min = i; 108// min_gamma = gamma_i; 109// } 110// } 111// 112// if (min_gamma < gamma) { 113// active.remove(new Integer(min)); 114// beta.setEntry(min,0.0); 115// return beta.add(delta.mapMultiplyToSelf(min_gamma)); 116// } 117// } 118// 119// return beta.add(delta.mapMultiplyToSelf(gamma)); 120 121 RealVector other = delta.mapMultiplyToSelf(gamma); 122 123 for (int i = 0; i < state.numFeatures; ++i) { 124 double betaElement = state.beta.getEntry(i); 125 double otherElement = other.getEntry(i); 126 if ((betaElement > 0 && betaElement + otherElement < 0) 127 || (betaElement < 0 && betaElement + otherElement > 0)) { 128 state.beta.setEntry(i,0.0); 129 other.setEntry(i,0.0); 130 Integer integer = i; 131 state.active.remove(integer); 132 state.activeSet.remove(integer); 133 } 134 } 135 136 return state.beta.add(other); 137 } 138 139 @Override 140 public String toString() { 141 return "LARSLassoTrainer(maxNumFeatures="+maxNumFeatures+")"; 142 } 143}