001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree;
018
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.common.tree.AbstractCARTTrainer;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impl.RegressorTrainingNode;
import org.tribuo.regression.rtree.impl.RegressorTrainingNode.InvertedData;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.util.Util;

import java.time.OffsetDateTime;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SplittableRandom;
047
048/**
049 * A {@link org.tribuo.Trainer} that uses an approximation of the CART algorithm to build a decision tree.
050 * Trains an independent tree for each output dimension.
051 * <p>
052 * See:
053 * <pre>
054 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
055 * "The Elements of Statistical Learning"
056 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
057 * </pre>
058 */
059public final class CARTRegressionTrainer extends AbstractCARTTrainer<Regressor> {
060
061    /**
062     * Impurity measure used to determine split quality.
063     */
064    @Config(description="Regression impurity measure used to determine split quality.")
065    private RegressorImpurity impurity = new MeanSquaredError();
066
067    /**
068     * Creates a CART Trainer.
069     *
070     * @param maxDepth maxDepth The maximum depth of the tree.
071     * @param minChildWeight minChildWeight The minimum node weight to consider it for a split.
072     * @param fractionFeaturesInSplit fractionFeaturesInSplit The fraction of features available in each split.
073     * @param impurity impurity The impurity function to use to determine split quality.
074     * @param seed The RNG seed.
075     */
076    public CARTRegressionTrainer(
077            int maxDepth,
078            float minChildWeight,
079            float fractionFeaturesInSplit,
080            RegressorImpurity impurity,
081            long seed
082    ) {
083        super(maxDepth, minChildWeight, fractionFeaturesInSplit, seed);
084        this.impurity = impurity;
085        postConfig();
086    }
087
088    /**
089     * Creates a CART trainer. Sets the impurity to the {@link MeanSquaredError}, uses
090     * all the features, and sets the minimum number of examples in a leaf to {@link #MIN_EXAMPLES}.
091     */
092    public CARTRegressionTrainer() {
093        this(Integer.MAX_VALUE);
094    }
095
096    /**
097     * Creates a CART trainer. Sets the impurity to the {@link MeanSquaredError}, uses
098     * all the features, and sets the minimum number of examples in a leaf to {@link #MIN_EXAMPLES}.
099     * @param maxDepth The maximum depth of the tree.
100     */
101    public CARTRegressionTrainer(int maxDepth) {
102        this(maxDepth, MIN_EXAMPLES, 1.0f, new MeanSquaredError(), Trainer.DEFAULT_SEED);
103    }
104
105    @Override
106    protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples) {
107        throw new IllegalStateException("Shouldn't reach here.");
108    }
109
110    @Override
111    public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
112        if (examples.getOutputInfo().getUnknownCount() > 0) {
113            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
114        }
115        // Creates a new RNG, adds one to the invocation count.
116        SplittableRandom localRNG;
117        TrainerProvenance trainerProvenance;
118        synchronized(this) {
119            localRNG = rng.split();
120            trainerProvenance = getProvenance();
121            trainInvocationCounter++;
122        }
123
124        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
125        ImmutableOutputInfo<Regressor> outputIDInfo = examples.getOutputIDInfo();
126        Set<Regressor> domain = outputIDInfo.getDomain();
127
128        int numFeaturesInSplit = Math.min(Math.round(fractionFeaturesInSplit * featureIDMap.size()),featureIDMap.size());
129        int[] indices;
130        int[] originalIndices = new int[featureIDMap.size()];
131        for (int i = 0; i < originalIndices.length; i++) {
132            originalIndices[i] = i;
133        }
134        if (numFeaturesInSplit != featureIDMap.size()) {
135            indices = new int[numFeaturesInSplit];
136            // log
137        } else {
138            indices = originalIndices;
139        }
140
141        InvertedData data = RegressorTrainingNode.invertData(examples);
142
143        Map<String, Node<Regressor>> nodeMap = new HashMap<>();
144        for (Regressor r : domain) {
145            String dimName = r.getNames()[0];
146            int dimIdx = outputIDInfo.getID(r);
147
148            AbstractTrainingNode<Regressor> root = new RegressorTrainingNode(impurity,data,dimIdx,dimName,examples.size(),featureIDMap,outputIDInfo);
149            Deque<AbstractTrainingNode<Regressor>> queue = new LinkedList<>();
150            queue.add(root);
151
152            while (!queue.isEmpty()) {
153                AbstractTrainingNode<Regressor> node = queue.poll();
154                if ((node.getDepth() < maxDepth) &&
155                        (node.getNumExamples() > minChildWeight)) {
156                    if (numFeaturesInSplit != featureIDMap.size()) {
157                        Util.randpermInPlace(originalIndices, localRNG);
158                        System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit);
159                    }
160                    List<AbstractTrainingNode<Regressor>> nodes = node.buildTree(indices);
161                    // Use the queue as a stack to improve cache locality.
162                    for (AbstractTrainingNode<Regressor> newNode : nodes) {
163                        queue.addFirst(newNode);
164                    }
165                }
166            }
167
168            nodeMap.put(dimName,root.convertTree());
169        }
170
171        ModelProvenance provenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
172        return new IndependentRegressionTreeModel("cart-tree", provenance, featureIDMap, outputIDInfo, false, nodeMap);
173    }
174
175    @Override
176    public String toString() {
177        StringBuilder buffer = new StringBuilder();
178
179        buffer.append("CARTRegressionTrainer(maxDepth=");
180        buffer.append(maxDepth);
181        buffer.append(",minChildWeight=");
182        buffer.append(minChildWeight);
183        buffer.append(",fractionFeaturesInSplit=");
184        buffer.append(fractionFeaturesInSplit);
185        buffer.append(",impurity=");
186        buffer.append(impurity.toString());
187        buffer.append(",seed=");
188        buffer.append(seed);
189        buffer.append(")");
190
191        return buffer.toString();
192    }
193
194    @Override
195    public TrainerProvenance getProvenance() {
196        return new TrainerProvenanceImpl(this);
197    }
198}