001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.tree;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Dataset;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Output;
025import org.tribuo.Trainer;
026import org.tribuo.provenance.ModelProvenance;
027import org.tribuo.provenance.SkeletalTrainerProvenance;
028import org.tribuo.provenance.TrainerProvenance;
029import org.tribuo.util.Util;
030
031import java.time.OffsetDateTime;
032import java.util.Collections;
033import java.util.Deque;
034import java.util.LinkedList;
035import java.util.List;
036import java.util.Map;
037import java.util.SplittableRandom;
038
039/**
040 * Base class for {@link org.tribuo.Trainer}'s that use an approximation of the CART algorithm to build a decision tree.
041 * <p>
042 * See:
043 * <pre>
044 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
045 * "The Elements of Statistical Learning"
046 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
047 * </pre>
048 */
049public abstract class AbstractCARTTrainer<T extends Output<T>> implements DecisionTreeTrainer<T> {
050
051    /**
052     * Default minimum weight of examples allowed in a leaf node.
053     */
054    public static final int MIN_EXAMPLES = 5;
055
056    /**
057     * Minimum weight of examples allowed in a leaf.
058     */
059    @Config(description="The minimum weight allowed in a child node.")
060    protected float minChildWeight = MIN_EXAMPLES;
061
062    /**
063     * Maximum tree depth. Integer.MAX_VALUE indicates the depth is unlimited.
064     */
065    @Config(description="The maximum depth of the tree.")
066    protected int maxDepth = Integer.MAX_VALUE;
067
068    /**
069     * Number of features to sample per split. 1 indicates all features are considered.
070     */
071    @Config(description="The fraction of features to consider in each split. 1.0f indicates all features are considered.")
072    protected float fractionFeaturesInSplit = 1.0f;
073
074    @Config(description="The RNG seed to use when sampling features in a split.")
075    protected long seed = Trainer.DEFAULT_SEED;
076
077    protected SplittableRandom rng;
078
079    protected int trainInvocationCounter;
080
081    /**
082     * After calls to this superconstructor subclasses must call postConfig().
083     * @param maxDepth The maximum depth of the tree.
084     * @param minChildWeight The minimum child weight allowed.
085     * @param fractionFeaturesInSplit The fraction of features to consider at each split.
086     * @param seed The seed for the feature subsampling RNG.
087     */
088    protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float fractionFeaturesInSplit, long seed) {
089        this.maxDepth = maxDepth;
090        this.fractionFeaturesInSplit = fractionFeaturesInSplit;
091        this.minChildWeight = minChildWeight;
092        this.seed = seed;
093    }
094
095    @Override
096    public synchronized void postConfig() {
097        this.rng = new SplittableRandom(seed);
098    }
099
100    @Override
101    public int getInvocationCount() {
102        return trainInvocationCounter;
103    }
104
105    @Override
106    public float getFractionFeaturesInSplit() {
107        return fractionFeaturesInSplit;
108    }
109
110    @Override
111    public TreeModel<T> train(Dataset<T> examples) {
112        return train(examples, Collections.emptyMap());
113    }
114
115    @Override
116    public TreeModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
117        if (examples.getOutputInfo().getUnknownCount() > 0) {
118            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
119        }
120        // Creates a new RNG, adds one to the invocation count.
121        SplittableRandom localRNG;
122        TrainerProvenance trainerProvenance;
123        synchronized(this) {
124            localRNG = rng.split();
125            trainerProvenance = getProvenance();
126            trainInvocationCounter++;
127        }
128
129        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
130        ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo();
131
132        int numFeaturesInSplit = Math.min(Math.round(fractionFeaturesInSplit * featureIDMap.size()),featureIDMap.size());
133        int[] indices;
134        int[] originalIndices = new int[featureIDMap.size()];
135        for (int i = 0; i < originalIndices.length; i++) {
136            originalIndices[i] = i;
137        }
138        if (numFeaturesInSplit != featureIDMap.size()) {
139            indices = new int[numFeaturesInSplit];
140            // log
141        } else {
142            indices = originalIndices;
143        }
144
145        AbstractTrainingNode<T> root = mkTrainingNode(examples);
146        Deque<AbstractTrainingNode<T>> queue = new LinkedList<>();
147        queue.add(root);
148
149        while (!queue.isEmpty()) {
150            AbstractTrainingNode<T> node = queue.poll();
151            if ((node.getDepth() < maxDepth) &&
152                    (node.getNumExamples() > minChildWeight)) {
153                if (numFeaturesInSplit != featureIDMap.size()) {
154                    Util.randpermInPlace(originalIndices, localRNG);
155                    System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit);
156                }
157                List<AbstractTrainingNode<T>> nodes = node.buildTree(indices);
158                // Use the queue as a stack to improve cache locality.
159                // Building depth first.
160                for (AbstractTrainingNode<T> newNode : nodes) {
161                    queue.addFirst(newNode);
162                }
163            }
164        }
165
166        ModelProvenance provenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
167        return new TreeModel<>("cart-tree", provenance, featureIDMap, outputIDInfo, false, root.convertTree());
168    }
169
170    protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> examples);
171
172    /**
173     * Provenance for {@link AbstractCARTTrainer}. No longer used.
174     */
175    @Deprecated
176    protected static abstract class AbstractCARTTrainerProvenance extends SkeletalTrainerProvenance {
177        private static final long serialVersionUID = 1L;
178
179        protected <T extends Output<T>> AbstractCARTTrainerProvenance(AbstractCARTTrainer<T> host) {
180            super(host);
181        }
182
183        protected AbstractCARTTrainerProvenance(Map<String,Provenance> map) {
184            super(map);
185        }
186    }
187
188}