001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.slm;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import com.oracle.labs.mlrg.olcut.util.Pair;
025import org.tribuo.Dataset;
026import org.tribuo.SparseModel;
027import org.tribuo.SparseTrainer;
028import org.tribuo.data.DataOptions;
029import org.tribuo.math.la.SparseVector;
030import org.tribuo.regression.RegressionFactory;
031import org.tribuo.regression.Regressor;
032import org.tribuo.regression.evaluation.RegressionEvaluation;
033import org.tribuo.util.Util;
034
035import java.io.IOException;
036import java.util.Map;
037import java.util.logging.Logger;
038
039/**
040 * Build and run a sparse linear regression model for a standard dataset.
041 */
042public class TrainTest {
043
044    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());
045
046    public enum SLMType { SFS, SFSN, LARS, LARSLASSO, ELASTICNET }
047
048    public static class LARSOptions implements Options {
049        @Override
050        public String getOptionsDescription() {
051            return "Trains and tests a sparse linear regression model on the specified datasets.";
052        }
053        public DataOptions general;
054
055        @Option(charName='m',longName="max-features-num", usage="Set the maximum number of features.")
056        public int maxNumFeatures = -1;
057        @Option(charName='a',longName="algorithm", usage="Choose the training algorithm (stepwise forward selection or least angle regression).")
058        public SLMType algorithm = SLMType.LARS;
059        @Option(charName='b',longName="alpha", usage="Regularisation strength in the Elastic Net.")
060        public double alpha = 1.0;
061        @Option(charName='l',longName="l1Ratio", usage="Ratio between the l1 and l2 penalties in the Elastic Net. Must be between 0 and 1.")
062        public double l1Ratio = 1.0;
063        @Option(longName="iterations",usage="Iterations of Elastic Net.")
064        public int iterations = 500;
065    }
066
067    /**
068     * @param args the command line arguments
069     * @throws IOException if there is any error reading the examples.
070     */
071    public static void main(String[] args) throws IOException {
072        //
073        // Use the labs format logging.
074        LabsLogFormatter.setAllLogFormatters();
075
076        LARSOptions o = new LARSOptions();
077        ConfigurationManager cm;
078        try {
079            cm = new ConfigurationManager(args,o);
080        } catch (UsageException e) {
081            logger.info(e.getMessage());
082            return;
083        }
084
085        if (o.general.trainingPath == null || o.general.testingPath == null) {
086            logger.info(cm.usage());
087            return;
088        }
089
090        RegressionFactory factory = new RegressionFactory();
091
092        Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory);
093        Dataset<Regressor> train = data.getA();
094        Dataset<Regressor> test = data.getB();
095
096        SparseTrainer<Regressor> trainer;
097
098        switch (o.algorithm) {
099            case SFS:
100                trainer = new SLMTrainer(false,Math.min(train.getFeatureMap().size(),o.maxNumFeatures));
101                break;
102            case LARS:
103                trainer = new LARSTrainer(Math.min(train.getFeatureMap().size(),o.maxNumFeatures));
104                break;
105            case LARSLASSO:
106                trainer = new LARSLassoTrainer(Math.min(train.getFeatureMap().size(),o.maxNumFeatures));
107                break;
108            case SFSN:
109                trainer = new SLMTrainer(true,Math.min(train.getFeatureMap().size(),o.maxNumFeatures));
110                break;
111            case ELASTICNET:
112                trainer = new ElasticNetCDTrainer(o.alpha,o.l1Ratio,1e-4,o.iterations,false,o.general.seed);
113                break;
114            default:
115                logger.warning("Unknown SLMType, found " + o.algorithm);
116                return;
117        }
118
119        logger.info("Training using " + trainer.toString());
120        final long trainStart = System.currentTimeMillis();
121        SparseModel<Regressor> model = trainer.train(train);
122        final long trainStop = System.currentTimeMillis();
123        logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop));
124        logger.info("Selected features: " + model.getActiveFeatures());
125        Map<String,SparseVector> weights = ((SparseLinearModel)model).getWeights();
126        for (Map.Entry<String,SparseVector> e : weights.entrySet()) {
127            logger.info("Target:" + e.getKey());
128            logger.info("\tWeights: " + e.getValue());
129            logger.info("\tWeights one norm: " + e.getValue().oneNorm());
130            logger.info("\tWeights two norm: " + e.getValue().twoNorm());
131        }
132        final long testStart = System.currentTimeMillis();
133        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test);
134        final long testStop = System.currentTimeMillis();
135        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
136        System.out.println(evaluation.toString());
137
138        if (o.general.outputPath != null) {
139            o.general.saveModel(model);
140        }
141    }
142}