001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.rtree; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.config.UsageException; 023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 024import com.oracle.labs.mlrg.olcut.util.Pair; 025import org.tribuo.Dataset; 026import org.tribuo.SparseModel; 027import org.tribuo.SparseTrainer; 028import org.tribuo.data.DataOptions; 029import org.tribuo.regression.RegressionFactory; 030import org.tribuo.regression.Regressor; 031import org.tribuo.regression.evaluation.RegressionEvaluation; 032import org.tribuo.regression.rtree.impurity.MeanAbsoluteError; 033import org.tribuo.regression.rtree.impurity.MeanSquaredError; 034import org.tribuo.regression.rtree.impurity.RegressorImpurity; 035import org.tribuo.util.Util; 036 037import java.io.IOException; 038import java.util.logging.Logger; 039 040/** 041 * Build and run a regression tree for a standard dataset. 042 */ 043public class TrainTest { 044 045 private static final Logger logger = Logger.getLogger(TrainTest.class.getName()); 046 047 public enum ImpurityType { MSE, MAE } 048 049 public enum TreeType {CART_INDEPENDENT, CART_JOINT} 050 051 public static class DecisionTreeOptions implements Options { 052 @Override 053 public String getOptionsDescription() { 054 return "Trains and tests a CART regression model on the specified datasets."; 055 } 056 public DataOptions general; 057 @Option(longName="csv-response-split-char",usage="Character to split the CSV response on to generate multiple regression dimensions. Defaults to ':'.") 058 public char splitChar = ':'; 059 @Option(charName='d',longName="max-depth",usage="Maximum depth in the decision tree.") 060 public int depth = 6; 061 @Option(charName='e',longName="split-fraction",usage="Fraction of features in split.") 062 public float fraction = 0.0f; 063 @Option(charName='m',longName="min-child-weight",usage="Minimum child weight.") 064 public float minChildWeight = 5.0f; 065 @Option(charName='n',longName="normalize",usage="Normalize the leaf outputs so each leaf sums to 1.0.") 066 public boolean normalize = false; 067 @Option(charName='i',longName="impurity",usage="Impurity measure to use. Defaults to MSE.") 068 public ImpurityType impurityType = ImpurityType.MSE; 069 @Option(charName='t',longName="tree-type",usage="Tree type.") 070 public TreeType treeType = TreeType.CART_INDEPENDENT; 071 @Option(longName="print-tree",usage="Prints the decision tree.") 072 public boolean printTree; 073 } 074 075 /** 076 * @param args the command line arguments 077 * @throws IOException if there is any error reading the examples. 078 */ 079 public static void main(String[] args) throws IOException { 080 081 // 082 // Use the labs format logging. 083 LabsLogFormatter.setAllLogFormatters(); 084 085 DecisionTreeOptions o = new DecisionTreeOptions(); 086 ConfigurationManager cm; 087 try { 088 cm = new ConfigurationManager(args,o); 089 } catch (UsageException e) { 090 logger.info(e.getMessage()); 091 return; 092 } 093 094 RegressionFactory factory = new RegressionFactory(o.splitChar); 095 096 Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory); 097 Dataset<Regressor> train = data.getA(); 098 Dataset<Regressor> test = data.getB(); 099 100 RegressorImpurity impurity; 101 switch (o.impurityType) { 102 case MAE: 103 impurity = new MeanAbsoluteError(); 104 break; 105 case MSE: 106 impurity = new MeanSquaredError(); 107 break; 108 default: 109 logger.severe("unknown impurity type " + o.impurityType); 110 return; 111 } 112 113 if (o.general.trainingPath == null || o.general.testingPath == null) { 114 logger.info(cm.usage()); 115 return; 116 } 117 118 SparseTrainer<Regressor> trainer; 119 switch (o.treeType) { 120 case CART_INDEPENDENT: 121 if (o.fraction <= 0) { 122 trainer = new CARTRegressionTrainer(o.depth,o.minChildWeight,1, impurity, o.general.seed); 123 } else { 124 trainer = new CARTRegressionTrainer(o.depth, o.minChildWeight, o.fraction, impurity, o.general.seed); 125 } 126 break; 127 case CART_JOINT: 128 if (o.fraction <= 0) { 129 trainer = new CARTJointRegressionTrainer(o.depth,o.minChildWeight,1, impurity, o.normalize, o.general.seed); 130 } else { 131 trainer = new CARTJointRegressionTrainer(o.depth, o.minChildWeight, o.fraction, impurity, o.normalize, o.general.seed); 132 } 133 break; 134 default: 135 logger.severe("unknown tree type " + o.treeType); 136 return; 137 } 138 139 logger.info("Training using " + trainer.toString()); 140 141 final long trainStart = System.currentTimeMillis(); 142 SparseModel<Regressor> model = trainer.train(train); 143 final long trainStop = System.currentTimeMillis(); 144 145 logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop)); 146 147 if (o.printTree) { 148 logger.info(model.toString()); 149 } 150 151 logger.info("Selected features: " + model.getActiveFeatures()); 152 final long testStart = System.currentTimeMillis(); 153 RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test); 154 final long testStop = System.currentTimeMillis(); 155 logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop)); 156 System.out.println(evaluation.toString()); 157 158 if (o.general.outputPath != null) { 159 o.general.saveModel(model); 160 } 161 } 162}