001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.libsvm;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import com.oracle.labs.mlrg.olcut.util.Pair;
025import org.tribuo.Dataset;
026import org.tribuo.Model;
027import org.tribuo.Trainer;
028import org.tribuo.common.libsvm.KernelType;
029import org.tribuo.common.libsvm.SVMParameters;
030import org.tribuo.data.DataOptions;
031import org.tribuo.regression.RegressionFactory;
032import org.tribuo.regression.Regressor;
033import org.tribuo.regression.evaluation.RegressionEvaluation;
034import org.tribuo.regression.libsvm.SVMRegressionType.SVMMode;
035import org.tribuo.util.Util;
036
037import java.io.IOException;
038import java.util.logging.Logger;
039
040/**
041 * Build and run a LibSVM regressor for a standard dataset.
042 */
043public class TrainTest {
044
045    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());
046
047    public static class LibLinearOptions implements Options {
048        @Override
049        public String getOptionsDescription() {
050            return "Trains and tests a LibSVM regression model on the specified datasets.";
051        }
052        public DataOptions general;
053
054        @Option(longName="coefficient",usage="Intercept in kernel function.")
055        public double coeff = 1.0;
056        @Option(charName='d',longName="degree",usage="Degree in polynomial kernel.")
057        public int degree = 3;
058        @Option(charName='g',longName="gamma",usage="Gamma value in kernel function.")
059        public double gamma = 0.0;
060        @Option(charName='k',longName="kernel",usage="Type of SVM kernel.")
061        public KernelType kernelType = KernelType.LINEAR;
062        @Option(charName='t',longName="type",usage="Type of SVM.")
063        public SVMRegressionType.SVMMode svmType = SVMMode.EPSILON_SVR;
064    }
065
066    /**
067     * @param args the command line arguments
068     * @throws IOException if there is any error reading the examples.
069     */
070    public static void main(String[] args) throws IOException {
071
072        //
073        // Use the labs format logging.
074        LabsLogFormatter.setAllLogFormatters();
075
076        LibLinearOptions o = new LibLinearOptions();
077        ConfigurationManager cm;
078        try {
079            cm = new ConfigurationManager(args,o);
080        } catch (UsageException e) {
081            logger.info(e.getMessage());
082            return;
083        }
084
085        if (o.general.trainingPath == null || o.general.testingPath == null) {
086            logger.info(cm.usage());
087            return;
088        }
089
090        RegressionFactory factory = new RegressionFactory();
091
092        Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory);
093        Dataset<Regressor> train = data.getA();
094        Dataset<Regressor> test = data.getB();
095
096        SVMParameters<Regressor> parameters = new SVMParameters<>(new SVMRegressionType(o.svmType), o.kernelType);
097        parameters.setGamma(o.gamma);
098        parameters.setCoeff(o.coeff);
099        parameters.setDegree(o.degree);
100        Trainer<Regressor> trainer = new LibSVMRegressionTrainer(parameters);
101        logger.info("Training using " + trainer.toString());
102
103        final long trainStart = System.currentTimeMillis();
104        Model<Regressor> model = trainer.train(train);
105        final long trainStop = System.currentTimeMillis();
106
107        logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop));
108        logger.info("Support vectors - " + ((LibSVMRegressionModel)model).getNumberOfSupportVectors());
109
110        final long testStart = System.currentTimeMillis();
111        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test);
112        final long testStop = System.currentTimeMillis();
113        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
114        System.out.println(evaluation.toString());
115
116        if (o.general.outputPath != null) {
117            o.general.saveModel(model);
118        }
119    }
120}