/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.liblinear;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.data.DataOptions;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.liblinear.LinearRegressionType.LinearType;
import org.tribuo.util.Util;

import java.io.IOException;
import java.util.logging.Logger;

/**
 * Command line program which trains a LibLinear regressor on a training dataset
 * and evaluates it on a testing dataset.
 */
public class TrainTest {

    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    /**
     * Command line options for the LibLinear regression train/test run.
     */
    public static class LibLinearOptions implements Options {
        /**
         * The shared data options; this program reads the training, testing and
         * output paths from it and uses it to load the datasets and save the model.
         */
        public DataOptions general;

        /** Cost penalty handed to the trainer. */
        @Option(charName = 'p', longName = "cost-penalty", usage = "Cost penalty for SVM.")
        public double cost = 1.0;

        /** Maximum number of iterations handed to the trainer. */
        @Option(longName = "max-iterations", usage = "Max iterations over the data.")
        public int maxIterations = 1000;

        /** Epsilon insensitivity handed to the trainer. */
        @Option(longName = "epsilon-insensitivity", usage = "Regression value insensitivity for margin.")
        public double epsilon = 0.1;

        /** Termination tolerance handed to the trainer. */
        @Option(charName = 'e', longName = "termination-criterion", usage = "Tolerance of the optimization termination criterion.")
        public double terminationCriterion = 0.01;

        /** The LibLinear regression algorithm to use. */
        @Option(charName = 't', longName = "algorithm", usage = "Type of SVR.")
        public LinearType algorithm = LinearType.L2R_L2LOSS_SVR;

        @Override
        public String getOptionsDescription() {
            return "Trains and tests a LibLinear regression model on the specified datasets.";
        }
    }

    /**
     * Parses the arguments, trains a regressor, evaluates it, prints the evaluation
     * to standard output, and saves the model when an output path was supplied.
     * <p>
     * If the arguments cannot be parsed, or either the training or testing path is
     * missing, a usage message is logged and the method returns without training.
     *
     * @param args the command line arguments
     * @throws IOException if there is any error reading the examples.
     */
    public static void main(String[] args) throws IOException {
        // Switch every log handler over to the labs log format before anything is logged.
        LabsLogFormatter.setAllLogFormatters();

        LibLinearOptions options = new LibLinearOptions();
        ConfigurationManager manager;
        try {
            manager = new ConfigurationManager(args, options);
        } catch (UsageException usage) {
            logger.info(usage.getMessage());
            return;
        }

        // Both datasets are mandatory; without them there is nothing to train or score.
        DataOptions general = options.general;
        boolean pathsSupplied = general.trainingPath != null && general.testingPath != null;
        if (!pathsSupplied) {
            logger.info(manager.usage());
            return;
        }

        RegressionFactory outputFactory = new RegressionFactory();

        Pair<Dataset<Regressor>, Dataset<Regressor>> splits = general.load(outputFactory);
        Dataset<Regressor> trainingData = splits.getA();
        Dataset<Regressor> testingData = splits.getB();

        LinearRegressionType svrType = new LinearRegressionType(options.algorithm);
        Trainer<Regressor> trainer = new LibLinearRegressionTrainer(
                svrType,
                options.cost,
                options.maxIterations,
                options.terminationCriterion,
                options.epsilon);
        logger.info("Training using " + trainer);

        long trainingBegan = System.currentTimeMillis();
        Model<Regressor> model = trainer.train(trainingData);
        long trainingEnded = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration(trainingBegan, trainingEnded));

        long testingBegan = System.currentTimeMillis();
        RegressionEvaluation evaluation = outputFactory.getEvaluator().evaluate(model, testingData);
        long testingEnded = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testingBegan, testingEnded));
        System.out.println(evaluation.toString());

        // Persist the model only when the user asked for it.
        if (general.outputPath != null) {
            general.saveModel(model);
        }
    }
}