001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.dtree; 018 019import com.oracle.labs.mlrg.olcut.config.Option; 020import org.tribuo.Trainer; 021import org.tribuo.classification.ClassificationOptions; 022import org.tribuo.classification.dtree.impurity.Entropy; 023import org.tribuo.classification.dtree.impurity.GiniIndex; 024import org.tribuo.classification.dtree.impurity.LabelImpurity; 025 026/** 027 * Options for building a classification tree trainer. 028 */ 029public class CARTClassificationOptions implements ClassificationOptions<CARTClassificationTrainer> { 030 031 @Override 032 public String getOptionsDescription() { 033 return "Options for decision/classification trees."; 034 } 035 036 public enum TreeType {CART} 037 038 public enum ImpurityType {GINI, ENTROPY} 039 040 @Option(longName = "cart-max-depth", usage = "Maximum depth in the decision tree.") 041 public int cartMaxDepth = 6; 042 @Option(longName = "cart-min-child-weight", usage = "Minimum child weight.") 043 public float cartMinChildWeight = 5.0f; 044 @Option(longName = "cart-split-fraction", usage = "Fraction of features in split.") 045 public float cartSplitFraction = 0.0f; 046 @Option(longName = "cart-impurity", usage = "Impurity measure to use. Defaults to GINI.") 047 public ImpurityType cartImpurity = ImpurityType.GINI; 048 @Option(longName = "cart-print-tree", usage = "Prints the decision tree.") 049 public boolean cartPrintTree; 050 @Option(longName = "cart-tree-algorithm", usage = "Tree algorithm to use (options are CART).") 051 public TreeType cartTreeAlgorithm = TreeType.CART; 052 @Option(longName = "cart-seed", usage = "RNG seed.") 053 public long cartSeed = Trainer.DEFAULT_SEED; 054 055 @Override 056 public CARTClassificationTrainer getTrainer() { 057 LabelImpurity impurity; 058 switch (cartImpurity) { 059 case GINI: 060 impurity = new GiniIndex(); 061 break; 062 case ENTROPY: 063 impurity = new Entropy(); 064 break; 065 default: 066 throw new IllegalArgumentException("unknown impurity type " + cartImpurity); 067 } 068 069 CARTClassificationTrainer trainer; 070 switch (cartTreeAlgorithm) { 071 case CART: 072 if (cartSplitFraction <= 0) { 073 trainer = new CARTClassificationTrainer(cartMaxDepth, cartMinChildWeight, 1, impurity, cartSeed); 074 } else { 075 trainer = new CARTClassificationTrainer(cartMaxDepth, cartMinChildWeight, cartSplitFraction, impurity, cartSeed); 076 } 077 break; 078 default: 079 throw new IllegalArgumentException("Unknown tree type " + cartTreeAlgorithm); 080 } 081 082 return trainer; 083 } 084 085}