001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.slm; 018 019import org.apache.commons.math3.linear.RealVector; 020 021import java.util.ArrayList; 022import java.util.Collections; 023import java.util.List; 024import java.util.logging.Logger; 025 026/** 027 * A trainer for a linear regression model which uses least angle regression. 028 * Each output dimension is trained independently. 029 * <p> 030 * See: 031 * <pre> 032 * Efron B, Hastie T, Johnstone I, Tibshirani R. 033 * "Least Angle Regression" 034 * The Annals of Statistics, 2004. 035 * </pre> 036 */ 037public class LARSTrainer extends SLMTrainer { 038 private static final Logger logger = Logger.getLogger(LARSTrainer.class.getName()); 039 040 /** 041 * Constructs a least angle regression trainer for a linear model. 042 * @param maxNumFeatures The maximum number of features to select. Supply -1 to select all features. 043 */ 044 public LARSTrainer(int maxNumFeatures) { 045 super(true,maxNumFeatures); 046 } 047 048 /** 049 * Constructs a least angle regression trainer that selects all the features. 050 */ 051 public LARSTrainer() { 052 this(-1); 053 } 054 055 @Override 056 protected RealVector newWeights(SLMState state) { 057 if (state.last) { 058 return super.newWeights(state); 059 } 060 061 RealVector deltapi = SLMTrainer.ordinaryLeastSquares(state.xpi,state.r); 062 063 if (deltapi == null) { 064 return null; 065 } 066 067 RealVector delta = state.unpack(deltapi); 068 069 // Computing gamma 070 List<Double> candidates = new ArrayList<>(); 071 072 double AA = SLMTrainer.sumInverted(state.xpi); 073 double CC = state.C; 074 075 RealVector wa = SLMTrainer.getwa(state.xpi,AA); 076 RealVector ar = SLMTrainer.getA(state.X, state.xpi,wa); 077 078 for (int i = 0; i < state.numFeatures; ++i) { 079 if (!state.activeSet.contains(i)) { 080 double c = state.corr.getEntry(i); 081 double a = ar.getEntry(i); 082 083 double v1 = (CC - c) / (AA - a); 084 double v2 = (CC + c) / (AA + a); 085 086 if (v1 >= 0) { 087 candidates.add(v1); 088 } 089 if (v2 >= 0) { 090 candidates.add(v2); 091 } 092 } 093 } 094 095 double gamma = Collections.min(candidates); 096 097 return state.beta.add(delta.mapMultiplyToSelf(gamma)); 098 } 099 100 @Override 101 public String toString() { 102 return "LARSTrainer(maxNumFeatures="+maxNumFeatures+")"; 103 } 104}