001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree;
018
import com.oracle.labs.mlrg.olcut.config.Config;
import org.tribuo.Dataset;
import org.tribuo.Trainer;
import org.tribuo.common.tree.AbstractCARTTrainer;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impl.JointRegressorTrainingNode;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;

import java.util.Objects;
030
031/**
032 * A {@link org.tribuo.Trainer} that uses an approximation of the CART algorithm to build a decision tree.
033 * <p>
034 * Builds a single tree for all the regression dimensions.
035 * <p>
036 * See:
037 * <pre>
038 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
039 * "The Elements of Statistical Learning"
040 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
041 * </pre>
042 */
043public class CARTJointRegressionTrainer extends AbstractCARTTrainer<Regressor> {
044
045    /**
046     * Impurity measure used to determine split quality.
047     */
048    @Config(description="The regression impurity to use.")
049    private RegressorImpurity impurity = new MeanSquaredError();
050
051    /**
052     * Normalizes the output of each leaf so it sums to one (i.e., is a probability distribution).
053     */
054    @Config(description="Normalize the output of each leaf so it sums to one.")
055    private boolean normalize = false;
056
057    /**
058     * Creates a CART Trainer.
059     *
060     * @param maxDepth maxDepth The maximum depth of the tree.
061     * @param minChildWeight minChildWeight The minimum node weight to consider it for a split.
062     * @param fractionFeaturesInSplit fractionFeaturesInSplit The fraction of features available in each split.
063     * @param impurity impurity The impurity function to use to determine split quality.
064     * @param normalize Normalize the leaves so each output sums to one.
065     * @param seed The seed to use for the RNG.
066     */
067    public CARTJointRegressionTrainer(
068            int maxDepth,
069            float minChildWeight,
070            float fractionFeaturesInSplit,
071            RegressorImpurity impurity,
072            boolean normalize,
073            long seed
074    ) {
075        super(maxDepth, minChildWeight, fractionFeaturesInSplit, seed);
076        this.impurity = impurity;
077        this.normalize = normalize;
078        postConfig();
079    }
080
081    /**
082     * Creates a CART Trainer. Sets the impurity to the {@link MeanSquaredError} and does not normalize the outputs.
083     */
084    public CARTJointRegressionTrainer() {
085        this(Integer.MAX_VALUE, MIN_EXAMPLES, 1.0f, new MeanSquaredError(), false, Trainer.DEFAULT_SEED);
086    }
087
088    /**
089     * Creates a CART Trainer. Sets the impurity to the {@link MeanSquaredError} and does not normalize the outputs.
090     * @param maxDepth The maximum depth of the tree.
091     */
092    public CARTJointRegressionTrainer(int maxDepth) {
093        this(maxDepth, MIN_EXAMPLES, 1.0f, new MeanSquaredError(), false, Trainer.DEFAULT_SEED);
094    }
095
096    /**
097     * Creates a CART Trainer. Sets the impurity to the {@link MeanSquaredError}.
098     * @param maxDepth The maximum depth of the tree.
099     * @param normalize Normalises the leaves so each leaf has a distribution which sums to 1.0.
100     */
101    public CARTJointRegressionTrainer(int maxDepth, boolean normalize) {
102        this(maxDepth, MIN_EXAMPLES, 1.0f, new MeanSquaredError(), normalize, Trainer.DEFAULT_SEED);
103    }
104
105    @Override
106    protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples) {
107        return new JointRegressorTrainingNode(impurity, examples, normalize);
108    }
109
110    @Override
111    public String toString() {
112        StringBuilder buffer = new StringBuilder();
113
114        buffer.append("CARTJointRegressionTrainer(maxDepth=");
115        buffer.append(maxDepth);
116        buffer.append(",minChildWeight=");
117        buffer.append(minChildWeight);
118        buffer.append(",fractionFeaturesInSplit=");
119        buffer.append(fractionFeaturesInSplit);
120        buffer.append(",impurity=");
121        buffer.append(impurity.toString());
122        buffer.append(",normalize=");
123        buffer.append(normalize);
124        buffer.append(",seed=");
125        buffer.append(seed);
126        buffer.append(")");
127
128        return buffer.toString();
129    }
130
131    @Override
132    public TrainerProvenance getProvenance() {
133        return new TrainerProvenanceImpl(this);
134    }
135}