001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.rtree; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import org.tribuo.Dataset; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Trainer; 025import org.tribuo.common.tree.AbstractCARTTrainer; 026import org.tribuo.common.tree.AbstractTrainingNode; 027import org.tribuo.common.tree.Node; 028import org.tribuo.common.tree.TreeModel; 029import org.tribuo.provenance.ModelProvenance; 030import org.tribuo.provenance.TrainerProvenance; 031import org.tribuo.provenance.impl.TrainerProvenanceImpl; 032import org.tribuo.regression.Regressor; 033import org.tribuo.regression.rtree.impl.RegressorTrainingNode; 034import org.tribuo.regression.rtree.impl.RegressorTrainingNode.InvertedData; 035import org.tribuo.regression.rtree.impurity.MeanSquaredError; 036import org.tribuo.regression.rtree.impurity.RegressorImpurity; 037import org.tribuo.util.Util; 038 039import java.time.OffsetDateTime; 040import java.util.Deque; 041import java.util.HashMap; 042import java.util.LinkedList; 043import java.util.List; 044import java.util.Map; 045import java.util.Set; 046import java.util.SplittableRandom; 047 048/** 049 * A {@link org.tribuo.Trainer} that uses an approximation of the CART algorithm to build a decision tree. 050 * Trains an independent tree for each output dimension. 051 * <p> 052 * See: 053 * <pre> 054 * J. Friedman, T. Hastie, & R. Tibshirani. 055 * "The Elements of Statistical Learning" 056 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a> 057 * </pre> 058 */ 059public final class CARTRegressionTrainer extends AbstractCARTTrainer<Regressor> { 060 061 /** 062 * Impurity measure used to determine split quality. 063 */ 064 @Config(description="Regression impurity measure used to determine split quality.") 065 private RegressorImpurity impurity = new MeanSquaredError(); 066 067 /** 068 * Creates a CART Trainer. 069 * 070 * @param maxDepth maxDepth The maximum depth of the tree. 071 * @param minChildWeight minChildWeight The minimum node weight to consider it for a split. 072 * @param fractionFeaturesInSplit fractionFeaturesInSplit The fraction of features available in each split. 073 * @param impurity impurity The impurity function to use to determine split quality. 074 * @param seed The RNG seed. 075 */ 076 public CARTRegressionTrainer( 077 int maxDepth, 078 float minChildWeight, 079 float fractionFeaturesInSplit, 080 RegressorImpurity impurity, 081 long seed 082 ) { 083 super(maxDepth, minChildWeight, fractionFeaturesInSplit, seed); 084 this.impurity = impurity; 085 postConfig(); 086 } 087 088 /** 089 * Creates a CART trainer. Sets the impurity to the {@link MeanSquaredError}, uses 090 * all the features, and sets the minimum number of examples in a leaf to {@link #MIN_EXAMPLES}. 091 */ 092 public CARTRegressionTrainer() { 093 this(Integer.MAX_VALUE); 094 } 095 096 /** 097 * Creates a CART trainer. Sets the impurity to the {@link MeanSquaredError}, uses 098 * all the features, and sets the minimum number of examples in a leaf to {@link #MIN_EXAMPLES}. 099 * @param maxDepth The maximum depth of the tree. 100 */ 101 public CARTRegressionTrainer(int maxDepth) { 102 this(maxDepth, MIN_EXAMPLES, 1.0f, new MeanSquaredError(), Trainer.DEFAULT_SEED); 103 } 104 105 @Override 106 protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples) { 107 throw new IllegalStateException("Shouldn't reach here."); 108 } 109 110 @Override 111 public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) { 112 if (examples.getOutputInfo().getUnknownCount() > 0) { 113 throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised."); 114 } 115 // Creates a new RNG, adds one to the invocation count. 116 SplittableRandom localRNG; 117 TrainerProvenance trainerProvenance; 118 synchronized(this) { 119 localRNG = rng.split(); 120 trainerProvenance = getProvenance(); 121 trainInvocationCounter++; 122 } 123 124 ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap(); 125 ImmutableOutputInfo<Regressor> outputIDInfo = examples.getOutputIDInfo(); 126 Set<Regressor> domain = outputIDInfo.getDomain(); 127 128 int numFeaturesInSplit = Math.min(Math.round(fractionFeaturesInSplit * featureIDMap.size()),featureIDMap.size()); 129 int[] indices; 130 int[] originalIndices = new int[featureIDMap.size()]; 131 for (int i = 0; i < originalIndices.length; i++) { 132 originalIndices[i] = i; 133 } 134 if (numFeaturesInSplit != featureIDMap.size()) { 135 indices = new int[numFeaturesInSplit]; 136 // log 137 } else { 138 indices = originalIndices; 139 } 140 141 InvertedData data = RegressorTrainingNode.invertData(examples); 142 143 Map<String, Node<Regressor>> nodeMap = new HashMap<>(); 144 for (Regressor r : domain) { 145 String dimName = r.getNames()[0]; 146 int dimIdx = outputIDInfo.getID(r); 147 148 AbstractTrainingNode<Regressor> root = new RegressorTrainingNode(impurity,data,dimIdx,dimName,examples.size(),featureIDMap,outputIDInfo); 149 Deque<AbstractTrainingNode<Regressor>> queue = new LinkedList<>(); 150 queue.add(root); 151 152 while (!queue.isEmpty()) { 153 AbstractTrainingNode<Regressor> node = queue.poll(); 154 if ((node.getDepth() < maxDepth) && 155 (node.getNumExamples() > minChildWeight)) { 156 if (numFeaturesInSplit != featureIDMap.size()) { 157 Util.randpermInPlace(originalIndices, localRNG); 158 System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit); 159 } 160 List<AbstractTrainingNode<Regressor>> nodes = node.buildTree(indices); 161 // Use the queue as a stack to improve cache locality. 162 for (AbstractTrainingNode<Regressor> newNode : nodes) { 163 queue.addFirst(newNode); 164 } 165 } 166 } 167 168 nodeMap.put(dimName,root.convertTree()); 169 } 170 171 ModelProvenance provenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance); 172 return new IndependentRegressionTreeModel("cart-tree", provenance, featureIDMap, outputIDInfo, false, nodeMap); 173 } 174 175 @Override 176 public String toString() { 177 StringBuilder buffer = new StringBuilder(); 178 179 buffer.append("CARTRegressionTrainer(maxDepth="); 180 buffer.append(maxDepth); 181 buffer.append(",minChildWeight="); 182 buffer.append(minChildWeight); 183 buffer.append(",fractionFeaturesInSplit="); 184 buffer.append(fractionFeaturesInSplit); 185 buffer.append(",impurity="); 186 buffer.append(impurity.toString()); 187 buffer.append(",seed="); 188 buffer.append(seed); 189 buffer.append(")"); 190 191 return buffer.toString(); 192 } 193 194 @Override 195 public TrainerProvenance getProvenance() { 196 return new TrainerProvenanceImpl(this); 197 } 198}