001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.rtree; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import org.tribuo.Dataset; 021import org.tribuo.Trainer; 022import org.tribuo.common.tree.AbstractCARTTrainer; 023import org.tribuo.common.tree.AbstractTrainingNode; 024import org.tribuo.provenance.TrainerProvenance; 025import org.tribuo.provenance.impl.TrainerProvenanceImpl; 026import org.tribuo.regression.Regressor; 027import org.tribuo.regression.rtree.impl.JointRegressorTrainingNode; 028import org.tribuo.regression.rtree.impurity.MeanSquaredError; 029import org.tribuo.regression.rtree.impurity.RegressorImpurity; 030 031/** 032 * A {@link org.tribuo.Trainer} that uses an approximation of the CART algorithm to build a decision tree. 033 * <p> 034 * Builds a single tree for all the regression dimensions. 035 * <p> 036 * See: 037 * <pre> 038 * J. Friedman, T. Hastie, & R. Tibshirani. 039 * "The Elements of Statistical Learning" 040 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a> 041 * </pre> 042 */ 043public class CARTJointRegressionTrainer extends AbstractCARTTrainer<Regressor> { 044 045 /** 046 * Impurity measure used to determine split quality. 047 */ 048 @Config(description="The regression impurity to use.") 049 private RegressorImpurity impurity = new MeanSquaredError(); 050 051 /** 052 * Normalizes the output of each leaf so it sums to one (i.e., is a probability distribution). 053 */ 054 @Config(description="Normalize the output of each leaf so it sums to one.") 055 private boolean normalize = false; 056 057 /** 058 * Creates a CART Trainer. 059 * 060 * @param maxDepth maxDepth The maximum depth of the tree. 061 * @param minChildWeight minChildWeight The minimum node weight to consider it for a split. 062 * @param fractionFeaturesInSplit fractionFeaturesInSplit The fraction of features available in each split. 063 * @param impurity impurity The impurity function to use to determine split quality. 064 * @param normalize Normalize the leaves so each output sums to one. 065 * @param seed The seed to use for the RNG. 066 */ 067 public CARTJointRegressionTrainer( 068 int maxDepth, 069 float minChildWeight, 070 float fractionFeaturesInSplit, 071 RegressorImpurity impurity, 072 boolean normalize, 073 long seed 074 ) { 075 super(maxDepth, minChildWeight, fractionFeaturesInSplit, seed); 076 this.impurity = impurity; 077 this.normalize = normalize; 078 postConfig(); 079 } 080 081 /** 082 * Creates a CART Trainer. Sets the impurity to the {@link MeanSquaredError} and does not normalize the outputs. 083 */ 084 public CARTJointRegressionTrainer() { 085 this(Integer.MAX_VALUE, MIN_EXAMPLES, 1.0f, new MeanSquaredError(), false, Trainer.DEFAULT_SEED); 086 } 087 088 /** 089 * Creates a CART Trainer. Sets the impurity to the {@link MeanSquaredError} and does not normalize the outputs. 090 * @param maxDepth The maximum depth of the tree. 091 */ 092 public CARTJointRegressionTrainer(int maxDepth) { 093 this(maxDepth, MIN_EXAMPLES, 1.0f, new MeanSquaredError(), false, Trainer.DEFAULT_SEED); 094 } 095 096 /** 097 * Creates a CART Trainer. Sets the impurity to the {@link MeanSquaredError}. 098 * @param maxDepth The maximum depth of the tree. 099 * @param normalize Normalises the leaves so each leaf has a distribution which sums to 1.0. 100 */ 101 public CARTJointRegressionTrainer(int maxDepth, boolean normalize) { 102 this(maxDepth, MIN_EXAMPLES, 1.0f, new MeanSquaredError(), normalize, Trainer.DEFAULT_SEED); 103 } 104 105 @Override 106 protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples) { 107 return new JointRegressorTrainingNode(impurity, examples, normalize); 108 } 109 110 @Override 111 public String toString() { 112 StringBuilder buffer = new StringBuilder(); 113 114 buffer.append("CARTJointRegressionTrainer(maxDepth="); 115 buffer.append(maxDepth); 116 buffer.append(",minChildWeight="); 117 buffer.append(minChildWeight); 118 buffer.append(",fractionFeaturesInSplit="); 119 buffer.append(fractionFeaturesInSplit); 120 buffer.append(",impurity="); 121 buffer.append(impurity.toString()); 122 buffer.append(",normalize="); 123 buffer.append(normalize); 124 buffer.append(",seed="); 125 buffer.append(seed); 126 buffer.append(")"); 127 128 return buffer.toString(); 129 } 130 131 @Override 132 public TrainerProvenance getProvenance() { 133 return new TrainerProvenanceImpl(this); 134 } 135}