001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.slm; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.config.UsageException; 023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 024import com.oracle.labs.mlrg.olcut.util.Pair; 025import org.tribuo.Dataset; 026import org.tribuo.SparseModel; 027import org.tribuo.SparseTrainer; 028import org.tribuo.data.DataOptions; 029import org.tribuo.math.la.SparseVector; 030import org.tribuo.regression.RegressionFactory; 031import org.tribuo.regression.Regressor; 032import org.tribuo.regression.evaluation.RegressionEvaluation; 033import org.tribuo.util.Util; 034 035import java.io.IOException; 036import java.util.Map; 037import java.util.logging.Logger; 038 039/** 040 * Build and run a sparse linear regression model for a standard dataset. 
041 */ 042public class TrainTest { 043 044 private static final Logger logger = Logger.getLogger(TrainTest.class.getName()); 045 046 public enum SLMType { SFS, SFSN, LARS, LARSLASSO, ELASTICNET } 047 048 public static class LARSOptions implements Options { 049 @Override 050 public String getOptionsDescription() { 051 return "Trains and tests a sparse linear regression model on the specified datasets."; 052 } 053 public DataOptions general; 054 055 @Option(charName='m',longName="max-features-num", usage="Set the maximum number of features.") 056 public int maxNumFeatures = -1; 057 @Option(charName='a',longName="algorithm", usage="Choose the training algorithm (stepwise forward selection or least angle regression).") 058 public SLMType algorithm = SLMType.LARS; 059 @Option(charName='b',longName="alpha", usage="Regularisation strength in the Elastic Net.") 060 public double alpha = 1.0; 061 @Option(charName='l',longName="l1Ratio", usage="Ratio between the l1 and l2 penalties in the Elastic Net. Must be between 0 and 1.") 062 public double l1Ratio = 1.0; 063 @Option(longName="iterations",usage="Iterations of Elastic Net.") 064 public int iterations = 500; 065 } 066 067 /** 068 * @param args the command line arguments 069 * @throws IOException if there is any error reading the examples. 070 */ 071 public static void main(String[] args) throws IOException { 072 // 073 // Use the labs format logging. 
074 LabsLogFormatter.setAllLogFormatters(); 075 076 LARSOptions o = new LARSOptions(); 077 ConfigurationManager cm; 078 try { 079 cm = new ConfigurationManager(args,o); 080 } catch (UsageException e) { 081 logger.info(e.getMessage()); 082 return; 083 } 084 085 if (o.general.trainingPath == null || o.general.testingPath == null) { 086 logger.info(cm.usage()); 087 return; 088 } 089 090 RegressionFactory factory = new RegressionFactory(); 091 092 Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory); 093 Dataset<Regressor> train = data.getA(); 094 Dataset<Regressor> test = data.getB(); 095 096 SparseTrainer<Regressor> trainer; 097 098 switch (o.algorithm) { 099 case SFS: 100 trainer = new SLMTrainer(false,Math.min(train.getFeatureMap().size(),o.maxNumFeatures)); 101 break; 102 case LARS: 103 trainer = new LARSTrainer(Math.min(train.getFeatureMap().size(),o.maxNumFeatures)); 104 break; 105 case LARSLASSO: 106 trainer = new LARSLassoTrainer(Math.min(train.getFeatureMap().size(),o.maxNumFeatures)); 107 break; 108 case SFSN: 109 trainer = new SLMTrainer(true,Math.min(train.getFeatureMap().size(),o.maxNumFeatures)); 110 break; 111 case ELASTICNET: 112 trainer = new ElasticNetCDTrainer(o.alpha,o.l1Ratio,1e-4,o.iterations,false,o.general.seed); 113 break; 114 default: 115 logger.warning("Unknown SLMType, found " + o.algorithm); 116 return; 117 } 118 119 logger.info("Training using " + trainer.toString()); 120 final long trainStart = System.currentTimeMillis(); 121 SparseModel<Regressor> model = trainer.train(train); 122 final long trainStop = System.currentTimeMillis(); 123 logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop)); 124 logger.info("Selected features: " + model.getActiveFeatures()); 125 Map<String,SparseVector> weights = ((SparseLinearModel)model).getWeights(); 126 for (Map.Entry<String,SparseVector> e : weights.entrySet()) { 127 logger.info("Target:" + e.getKey()); 128 logger.info("\tWeights: " + 
e.getValue()); 129 logger.info("\tWeights one norm: " + e.getValue().oneNorm()); 130 logger.info("\tWeights two norm: " + e.getValue().twoNorm()); 131 } 132 final long testStart = System.currentTimeMillis(); 133 RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test); 134 final long testStop = System.currentTimeMillis(); 135 logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop)); 136 System.out.println(evaluation.toString()); 137 138 if (o.general.outputPath != null) { 139 o.general.saveModel(model); 140 } 141 } 142}