001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.dtree; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import org.tribuo.Dataset; 021import org.tribuo.Trainer; 022import org.tribuo.classification.Label; 023import org.tribuo.classification.dtree.impl.ClassifierTrainingNode; 024import org.tribuo.classification.dtree.impurity.GiniIndex; 025import org.tribuo.classification.dtree.impurity.LabelImpurity; 026import org.tribuo.common.tree.AbstractCARTTrainer; 027import org.tribuo.common.tree.AbstractTrainingNode; 028import org.tribuo.provenance.TrainerProvenance; 029import org.tribuo.provenance.impl.TrainerProvenanceImpl; 030 031/** 032 * A {@link org.tribuo.Trainer} that uses an approximation of the CART algorithm to build a decision tree. 033 * <p> 034 * See: 035 * <pre> 036 * J. Friedman, T. Hastie, & R. Tibshirani. 037 * "The Elements of Statistical Learning" 038 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a> 039 * </pre> 040 */ 041public class CARTClassificationTrainer extends AbstractCARTTrainer<Label> { 042 043 /** 044 * Impurity measure used to determine split quality. 045 */ 046 @Config(description = "The impurity measure used to determine split quality.") 047 private LabelImpurity impurity = new GiniIndex(); 048 049 /** 050 * Creates a CART Trainer. 051 * 052 * @param maxDepth The maximum depth of the tree. 053 * @param minChildWeight The minimum node weight to consider it for a split. 054 * @param fractionFeaturesInSplit The fraction of features available in each split. 055 * @param impurity Impurity measure to determine split quality. See {@link LabelImpurity}. 056 * @param seed The RNG seed. 057 */ 058 public CARTClassificationTrainer( 059 int maxDepth, 060 float minChildWeight, 061 float fractionFeaturesInSplit, 062 LabelImpurity impurity, 063 long seed 064 ) { 065 super(maxDepth, minChildWeight, fractionFeaturesInSplit, seed); 066 this.impurity = impurity; 067 postConfig(); 068 } 069 070 /** 071 * Creates a CART Trainer. Sets the impurity to the {@link GiniIndex}. 072 */ 073 public CARTClassificationTrainer() { 074 this(Integer.MAX_VALUE); 075 } 076 077 /** 078 * Creates a CART trainer. Sets the impurity to the {@link GiniIndex}, uses 079 * all the features, and sets the minimum number of examples in a leaf to {@link #MIN_EXAMPLES}. 080 * @param maxDepth The maximum depth of the tree. 081 */ 082 public CARTClassificationTrainer(int maxDepth) { 083 this(maxDepth, MIN_EXAMPLES, 1.0f, new GiniIndex(), Trainer.DEFAULT_SEED); 084 } 085 086 /** 087 * Creates a CART Trainer. Sets the impurity to the {@link GiniIndex}. 088 * @param maxDepth The maximum depth of the tree. 089 * @param fractionFeaturesInSplit The fraction of features available in each split. 090 * @param seed The seed for the RNG. 091 */ 092 public CARTClassificationTrainer(int maxDepth, float fractionFeaturesInSplit, long seed) { 093 this(maxDepth, MIN_EXAMPLES, fractionFeaturesInSplit, new GiniIndex(), seed); 094 } 095 096 @Override 097 protected AbstractTrainingNode<Label> mkTrainingNode(Dataset<Label> examples) { 098 return new ClassifierTrainingNode(impurity, examples); 099 } 100 101 @Override 102 public String toString() { 103 StringBuilder buffer = new StringBuilder(); 104 105 buffer.append("CARTClassificationTrainer(maxDepth="); 106 buffer.append(maxDepth); 107 buffer.append(",minChildWeight="); 108 buffer.append(minChildWeight); 109 buffer.append(",fractionFeaturesInSplit="); 110 buffer.append(fractionFeaturesInSplit); 111 buffer.append(",impurity="); 112 buffer.append(impurity.toString()); 113 buffer.append(",seed="); 114 buffer.append(seed); 115 buffer.append(")"); 116 117 return buffer.toString(); 118 } 119 120 @Override 121 public TrainerProvenance getProvenance() { 122 return new TrainerProvenanceImpl(this); 123 } 124}