001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.libsvm; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.config.UsageException; 023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 024import com.oracle.labs.mlrg.olcut.util.Pair; 025import org.tribuo.Dataset; 026import org.tribuo.Model; 027import org.tribuo.Trainer; 028import org.tribuo.common.libsvm.KernelType; 029import org.tribuo.common.libsvm.SVMParameters; 030import org.tribuo.data.DataOptions; 031import org.tribuo.regression.RegressionFactory; 032import org.tribuo.regression.Regressor; 033import org.tribuo.regression.evaluation.RegressionEvaluation; 034import org.tribuo.regression.libsvm.SVMRegressionType.SVMMode; 035import org.tribuo.util.Util; 036 037import java.io.IOException; 038import java.util.logging.Logger; 039 040/** 041 * Build and run a LibSVM regressor for a standard dataset. 042 */ 043public class TrainTest { 044 045 private static final Logger logger = Logger.getLogger(TrainTest.class.getName()); 046 047 public static class LibLinearOptions implements Options { 048 @Override 049 public String getOptionsDescription() { 050 return "Trains and tests a LibSVM regression model on the specified datasets."; 051 } 052 public DataOptions general; 053 054 @Option(longName="coefficient",usage="Intercept in kernel function.") 055 public double coeff = 1.0; 056 @Option(charName='d',longName="degree",usage="Degree in polynomial kernel.") 057 public int degree = 3; 058 @Option(charName='g',longName="gamma",usage="Gamma value in kernel function.") 059 public double gamma = 0.0; 060 @Option(charName='k',longName="kernel",usage="Type of SVM kernel.") 061 public KernelType kernelType = KernelType.LINEAR; 062 @Option(charName='t',longName="type",usage="Type of SVM.") 063 public SVMRegressionType.SVMMode svmType = SVMMode.EPSILON_SVR; 064 } 065 066 /** 067 * @param args the command line arguments 068 * @throws IOException if there is any error reading the examples. 069 */ 070 public static void main(String[] args) throws IOException { 071 072 // 073 // Use the labs format logging. 074 LabsLogFormatter.setAllLogFormatters(); 075 076 LibLinearOptions o = new LibLinearOptions(); 077 ConfigurationManager cm; 078 try { 079 cm = new ConfigurationManager(args,o); 080 } catch (UsageException e) { 081 logger.info(e.getMessage()); 082 return; 083 } 084 085 if (o.general.trainingPath == null || o.general.testingPath == null) { 086 logger.info(cm.usage()); 087 return; 088 } 089 090 RegressionFactory factory = new RegressionFactory(); 091 092 Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory); 093 Dataset<Regressor> train = data.getA(); 094 Dataset<Regressor> test = data.getB(); 095 096 SVMParameters<Regressor> parameters = new SVMParameters<>(new SVMRegressionType(o.svmType), o.kernelType); 097 parameters.setGamma(o.gamma); 098 parameters.setCoeff(o.coeff); 099 parameters.setDegree(o.degree); 100 Trainer<Regressor> trainer = new LibSVMRegressionTrainer(parameters); 101 logger.info("Training using " + trainer.toString()); 102 103 final long trainStart = System.currentTimeMillis(); 104 Model<Regressor> model = trainer.train(train); 105 final long trainStop = System.currentTimeMillis(); 106 107 logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop)); 108 logger.info("Support vectors - " + ((LibSVMRegressionModel)model).getNumberOfSupportVectors()); 109 110 final long testStart = System.currentTimeMillis(); 111 RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test); 112 final long testStop = System.currentTimeMillis(); 113 logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop)); 114 System.out.println(evaluation.toString()); 115 116 if (o.general.outputPath != null) { 117 o.general.saveModel(model); 118 } 119 } 120}