/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.data.DataOptions;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.rtree.impurity.MeanAbsoluteError;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.util.Util;

import java.io.IOException;
import java.util.logging.Logger;

/**
 * Build and run a regression tree for a standard dataset.
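 * <p>
 * An illustrative invocation might look like the following. The {@code --max-depth},
 * {@code --impurity}, {@code --tree-type} and {@code --print-tree} flags are defined in
 * {@link DecisionTreeOptions}; the data file flag names shown here are assumptions - see
 * {@link DataOptions} for the flags it actually exposes:
 * <pre>{@code
 * java org.tribuo.regression.rtree.TrainTest \
 *      --training-file train.csv --testing-file test.csv \
 *      --max-depth 8 --impurity MAE --tree-type CART_JOINT --print-tree
 * }</pre>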
 */
public class TrainTest {

    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

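    /**
     * The impurity measure used to select splits when growing the tree.
     */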
    public enum ImpurityType { MSE, MAE }

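    /**
     * The type of CART tree to train: an independent tree per output dimension,
     * or a single joint tree over all the output dimensions.
     */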
    public enum TreeType { CART_INDEPENDENT, CART_JOINT }

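    /**
     * Command line options for training and testing a CART regression tree.
     */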
    public static class DecisionTreeOptions implements Options {
        @Override
        public String getOptionsDescription() {
            return "Trains and tests a CART regression model on the specified datasets.";
        }
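        /**
         * The data loading options.
         */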
        public DataOptions general;
        @Option(longName="csv-response-split-char",usage="Character to split the CSV response on to generate multiple regression dimensions. Defaults to ':'.")
        public char splitChar = ':';
        @Option(charName='d',longName="max-depth",usage="Maximum depth in the decision tree.")
        public int depth = 6;
        @Option(charName='e',longName="split-fraction",usage="Fraction of features in split.")
        public float fraction = 0.0f;
        @Option(charName='m',longName="min-child-weight",usage="Minimum child weight.")
        public float minChildWeight = 5.0f;
        @Option(charName='n',longName="normalize",usage="Normalize the leaf outputs so each leaf sums to 1.0.")
        public boolean normalize = false;
        @Option(charName='i',longName="impurity",usage="Impurity measure to use. Defaults to MSE.")
        public ImpurityType impurityType = ImpurityType.MSE;
        @Option(charName='t',longName="tree-type",usage="Tree type.")
        public TreeType treeType = TreeType.CART_INDEPENDENT;
        @Option(longName="print-tree",usage="Prints the decision tree.")
        public boolean printTree;
    }

    /**
     * @param args the command line arguments
     * @throws IOException if there is any error reading the examples.
     */
    public static void main(String[] args) throws IOException {

        //
        // Use the labs format logging.
        LabsLogFormatter.setAllLogFormatters();

        DecisionTreeOptions o = new DecisionTreeOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args,o);
        } catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }

        // Print usage and bail out if either data path is missing.
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }

        RegressionFactory factory = new RegressionFactory(o.splitChar);

        Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory);
        Dataset<Regressor> train = data.getA();
        Dataset<Regressor> test = data.getB();

        RegressorImpurity impurity;
        switch (o.impurityType) {
            case MAE:
                impurity = new MeanAbsoluteError();
                break;
            case MSE:
                impurity = new MeanSquaredError();
                break;
            default:
                logger.severe("unknown impurity type " + o.impurityType);
                return;
        }

        SparseTrainer<Regressor> trainer;
        switch (o.treeType) {
            case CART_INDEPENDENT:
                // A non-positive split fraction is replaced with 1, i.e. every feature is considered at each split.
                if (o.fraction <= 0) {
                    trainer = new CARTRegressionTrainer(o.depth, o.minChildWeight, 1, impurity, o.general.seed);
                } else {
                    trainer = new CARTRegressionTrainer(o.depth, o.minChildWeight, o.fraction, impurity, o.general.seed);
                }
                break;
            case CART_JOINT:
                if (o.fraction <= 0) {
                    trainer = new CARTJointRegressionTrainer(o.depth, o.minChildWeight, 1, impurity, o.normalize, o.general.seed);
                } else {
                    trainer = new CARTJointRegressionTrainer(o.depth, o.minChildWeight, o.fraction, impurity, o.normalize, o.general.seed);
                }
                break;
            default:
                logger.severe("unknown tree type " + o.treeType);
                return;
        }

        logger.info("Training using " + trainer.toString());

        final long trainStart = System.currentTimeMillis();
        SparseModel<Regressor> model = trainer.train(train);
        final long trainStop = System.currentTimeMillis();

        logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop));

        if (o.printTree) {
            logger.info(model.toString());
        }

        logger.info("Selected features: " + model.getActiveFeatures());
        final long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
        System.out.println(evaluation.toString());

        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    }
}