001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.dtree;
018
019import com.oracle.labs.mlrg.olcut.config.Option;
020import org.tribuo.Trainer;
021import org.tribuo.classification.ClassificationOptions;
022import org.tribuo.classification.dtree.impurity.Entropy;
023import org.tribuo.classification.dtree.impurity.GiniIndex;
024import org.tribuo.classification.dtree.impurity.LabelImpurity;
025
026/**
027 * Options for building a classification tree trainer.
028 */
029public class CARTClassificationOptions implements ClassificationOptions<CARTClassificationTrainer> {
030
031    @Override
032    public String getOptionsDescription() {
033        return "Options for decision/classification trees.";
034    }
035
036    public enum TreeType {CART}
037
038    public enum ImpurityType {GINI, ENTROPY}
039
040    @Option(longName = "cart-max-depth", usage = "Maximum depth in the decision tree.")
041    public int cartMaxDepth = 6;
042    @Option(longName = "cart-min-child-weight", usage = "Minimum child weight.")
043    public float cartMinChildWeight = 5.0f;
044    @Option(longName = "cart-split-fraction", usage = "Fraction of features in split.")
045    public float cartSplitFraction = 0.0f;
046    @Option(longName = "cart-impurity", usage = "Impurity measure to use. Defaults to GINI.")
047    public ImpurityType cartImpurity = ImpurityType.GINI;
048    @Option(longName = "cart-print-tree", usage = "Prints the decision tree.")
049    public boolean cartPrintTree;
050    @Option(longName = "cart-tree-algorithm", usage = "Tree algorithm to use (options are CART).")
051    public TreeType cartTreeAlgorithm = TreeType.CART;
052    @Option(longName = "cart-seed", usage = "RNG seed.")
053    public long cartSeed = Trainer.DEFAULT_SEED;
054
055    @Override
056    public CARTClassificationTrainer getTrainer() {
057        LabelImpurity impurity;
058        switch (cartImpurity) {
059            case GINI:
060                impurity = new GiniIndex();
061                break;
062            case ENTROPY:
063                impurity = new Entropy();
064                break;
065            default:
066                throw new IllegalArgumentException("unknown impurity type " + cartImpurity);
067        }
068
069        CARTClassificationTrainer trainer;
070        switch (cartTreeAlgorithm) {
071            case CART:
072                if (cartSplitFraction <= 0) {
073                    trainer = new CARTClassificationTrainer(cartMaxDepth, cartMinChildWeight, 1, impurity, cartSeed);
074                } else {
075                    trainer = new CARTClassificationTrainer(cartMaxDepth, cartMinChildWeight, cartSplitFraction, impurity, cartSeed);
076                }
077                break;
078            default:
079                throw new IllegalArgumentException("Unknown tree type " + cartTreeAlgorithm);
080        }
081
082        return trainer;
083    }
084
085}