001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.tree; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import org.tribuo.Dataset; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Output; 025import org.tribuo.Trainer; 026import org.tribuo.provenance.ModelProvenance; 027import org.tribuo.provenance.SkeletalTrainerProvenance; 028import org.tribuo.provenance.TrainerProvenance; 029import org.tribuo.util.Util; 030 031import java.time.OffsetDateTime; 032import java.util.Collections; 033import java.util.Deque; 034import java.util.LinkedList; 035import java.util.List; 036import java.util.Map; 037import java.util.SplittableRandom; 038 039/** 040 * Base class for {@link org.tribuo.Trainer}'s that use an approximation of the CART algorithm to build a decision tree. 041 * <p> 042 * See: 043 * <pre> 044 * J. Friedman, T. Hastie, & R. Tibshirani. 045 * "The Elements of Statistical Learning" 046 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a> 047 * </pre> 048 */ 049public abstract class AbstractCARTTrainer<T extends Output<T>> implements DecisionTreeTrainer<T> { 050 051 /** 052 * Default minimum weight of examples allowed in a leaf node. 053 */ 054 public static final int MIN_EXAMPLES = 5; 055 056 /** 057 * Minimum weight of examples allowed in a leaf. 058 */ 059 @Config(description="The minimum weight allowed in a child node.") 060 protected float minChildWeight = MIN_EXAMPLES; 061 062 /** 063 * Maximum tree depth. Integer.MAX_VALUE indicates the depth is unlimited. 064 */ 065 @Config(description="The maximum depth of the tree.") 066 protected int maxDepth = Integer.MAX_VALUE; 067 068 /** 069 * Number of features to sample per split. 1 indicates all features are considered. 070 */ 071 @Config(description="The fraction of features to consider in each split. 1.0f indicates all features are considered.") 072 protected float fractionFeaturesInSplit = 1.0f; 073 074 @Config(description="The RNG seed to use when sampling features in a split.") 075 protected long seed = Trainer.DEFAULT_SEED; 076 077 protected SplittableRandom rng; 078 079 protected int trainInvocationCounter; 080 081 /** 082 * After calls to this superconstructor subclasses must call postConfig(). 083 * @param maxDepth The maximum depth of the tree. 084 * @param minChildWeight The minimum child weight allowed. 085 * @param fractionFeaturesInSplit The fraction of features to consider at each split. 086 * @param seed The seed for the feature subsampling RNG. 087 */ 088 protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float fractionFeaturesInSplit, long seed) { 089 this.maxDepth = maxDepth; 090 this.fractionFeaturesInSplit = fractionFeaturesInSplit; 091 this.minChildWeight = minChildWeight; 092 this.seed = seed; 093 } 094 095 @Override 096 public synchronized void postConfig() { 097 this.rng = new SplittableRandom(seed); 098 } 099 100 @Override 101 public int getInvocationCount() { 102 return trainInvocationCounter; 103 } 104 105 @Override 106 public float getFractionFeaturesInSplit() { 107 return fractionFeaturesInSplit; 108 } 109 110 @Override 111 public TreeModel<T> train(Dataset<T> examples) { 112 return train(examples, Collections.emptyMap()); 113 } 114 115 @Override 116 public TreeModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) { 117 if (examples.getOutputInfo().getUnknownCount() > 0) { 118 throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised."); 119 } 120 // Creates a new RNG, adds one to the invocation count. 121 SplittableRandom localRNG; 122 TrainerProvenance trainerProvenance; 123 synchronized(this) { 124 localRNG = rng.split(); 125 trainerProvenance = getProvenance(); 126 trainInvocationCounter++; 127 } 128 129 ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap(); 130 ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo(); 131 132 int numFeaturesInSplit = Math.min(Math.round(fractionFeaturesInSplit * featureIDMap.size()),featureIDMap.size()); 133 int[] indices; 134 int[] originalIndices = new int[featureIDMap.size()]; 135 for (int i = 0; i < originalIndices.length; i++) { 136 originalIndices[i] = i; 137 } 138 if (numFeaturesInSplit != featureIDMap.size()) { 139 indices = new int[numFeaturesInSplit]; 140 // log 141 } else { 142 indices = originalIndices; 143 } 144 145 AbstractTrainingNode<T> root = mkTrainingNode(examples); 146 Deque<AbstractTrainingNode<T>> queue = new LinkedList<>(); 147 queue.add(root); 148 149 while (!queue.isEmpty()) { 150 AbstractTrainingNode<T> node = queue.poll(); 151 if ((node.getDepth() < maxDepth) && 152 (node.getNumExamples() > minChildWeight)) { 153 if (numFeaturesInSplit != featureIDMap.size()) { 154 Util.randpermInPlace(originalIndices, localRNG); 155 System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit); 156 } 157 List<AbstractTrainingNode<T>> nodes = node.buildTree(indices); 158 // Use the queue as a stack to improve cache locality. 159 // Building depth first. 160 for (AbstractTrainingNode<T> newNode : nodes) { 161 queue.addFirst(newNode); 162 } 163 } 164 } 165 166 ModelProvenance provenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance); 167 return new TreeModel<>("cart-tree", provenance, featureIDMap, outputIDInfo, false, root.convertTree()); 168 } 169 170 protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> examples); 171 172 /** 173 * Provenance for {@link AbstractCARTTrainer}. No longer used. 174 */ 175 @Deprecated 176 protected static abstract class AbstractCARTTrainerProvenance extends SkeletalTrainerProvenance { 177 private static final long serialVersionUID = 1L; 178 179 protected <T extends Output<T>> AbstractCARTTrainerProvenance(AbstractCARTTrainer<T> host) { 180 super(host); 181 } 182 183 protected AbstractCARTTrainerProvenance(Map<String,Provenance> map) { 184 super(map); 185 } 186 } 187 188}