001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.liblinear;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import com.oracle.labs.mlrg.olcut.util.Pair;
025import org.tribuo.Dataset;
026import org.tribuo.Model;
027import org.tribuo.Trainer;
028import org.tribuo.data.DataOptions;
029import org.tribuo.regression.RegressionFactory;
030import org.tribuo.regression.Regressor;
031import org.tribuo.regression.evaluation.RegressionEvaluation;
032import org.tribuo.regression.liblinear.LinearRegressionType.LinearType;
033import org.tribuo.util.Util;
034
035import java.io.IOException;
036import java.util.logging.Logger;
037
038/**
039 * Build and run a LibLinear regressor for a standard dataset.
040 */
041public class TrainTest {
042
043    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());
044
045    public static class LibLinearOptions implements Options {
046        @Override
047        public String getOptionsDescription() {
048            return "Trains and tests a LibLinear regression model on the specified datasets.";
049        }
050        public DataOptions general;
051
052        @Option(charName='p',longName="cost-penalty",usage="Cost penalty for SVM.")
053        public double cost = 1.0;
054        @Option(longName = "max-iterations", usage = "Max iterations over the data.")
055        public int maxIterations = 1000;
056        @Option(longName="epsilon-insensitivity",usage="Regression value insensitivity for margin.")
057        public double epsilon = 0.1;
058        @Option(charName='e',longName="termination-criterion",usage="Tolerance of the optimization termination criterion.")
059        public double terminationCriterion = 0.01;
060        @Option(charName='t',longName="algorithm",usage="Type of SVR.")
061        public LinearType algorithm = LinearType.L2R_L2LOSS_SVR;
062    }
063
064    /**
065     * @param args the command line arguments
066     * @throws IOException if there is any error reading the examples.
067     */
068    public static void main(String[] args) throws IOException {
069
070        //
071        // Use the labs format logging.
072        LabsLogFormatter.setAllLogFormatters();
073
074        LibLinearOptions o = new LibLinearOptions();
075        ConfigurationManager cm;
076        try {
077            cm = new ConfigurationManager(args,o);
078        } catch (UsageException e) {
079            logger.info(e.getMessage());
080            return;
081        }
082
083        if (o.general.trainingPath == null || o.general.testingPath == null) {
084            logger.info(cm.usage());
085            return;
086        }
087
088        RegressionFactory factory = new RegressionFactory();
089
090        Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory);
091        Dataset<Regressor> train = data.getA();
092        Dataset<Regressor> test = data.getB();
093
094        Trainer<Regressor> trainer = new LibLinearRegressionTrainer(new LinearRegressionType(o.algorithm),o.cost,o.maxIterations,o.terminationCriterion,o.epsilon);
095        logger.info("Training using " + trainer.toString());
096
097        final long trainStart = System.currentTimeMillis();
098        Model<Regressor> model = trainer.train(train);
099        final long trainStop = System.currentTimeMillis();
100
101        logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop));
102
103        final long testStart = System.currentTimeMillis();
104        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test);
105        final long testStop = System.currentTimeMillis();
106        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
107        System.out.println(evaluation.toString());
108
109        if (o.general.outputPath != null) {
110            o.general.saveModel(model);
111        }
112    }
113}