001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.dtree;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import org.tribuo.Dataset;
021import org.tribuo.Trainer;
022import org.tribuo.classification.Label;
023import org.tribuo.classification.dtree.impl.ClassifierTrainingNode;
024import org.tribuo.classification.dtree.impurity.GiniIndex;
025import org.tribuo.classification.dtree.impurity.LabelImpurity;
026import org.tribuo.common.tree.AbstractCARTTrainer;
027import org.tribuo.common.tree.AbstractTrainingNode;
028import org.tribuo.provenance.TrainerProvenance;
029import org.tribuo.provenance.impl.TrainerProvenanceImpl;
030
031/**
032 * A {@link org.tribuo.Trainer} that uses an approximation of the CART algorithm to build a decision tree.
033 * <p>
034 * See:
035 * <pre>
036 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
037 * "The Elements of Statistical Learning"
038 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
039 * </pre>
040 */
041public class CARTClassificationTrainer extends AbstractCARTTrainer<Label> {
042
043    /**
044     * Impurity measure used to determine split quality.
045     */
046    @Config(description = "The impurity measure used to determine split quality.")
047    private LabelImpurity impurity = new GiniIndex();
048
049    /**
050     * Creates a CART Trainer.
051     *
052     * @param maxDepth The maximum depth of the tree.
053     * @param minChildWeight The minimum node weight to consider it for a split.
054     * @param fractionFeaturesInSplit The fraction of features available in each split.
055     * @param impurity Impurity measure to determine split quality. See {@link LabelImpurity}.
056     * @param seed The RNG seed.
057     */
058    public CARTClassificationTrainer(
059            int maxDepth,
060            float minChildWeight,
061            float fractionFeaturesInSplit,
062            LabelImpurity impurity,
063            long seed
064    ) {
065        super(maxDepth, minChildWeight, fractionFeaturesInSplit, seed);
066        this.impurity = impurity;
067        postConfig();
068    }
069
070    /**
071     * Creates a CART Trainer. Sets the impurity to the {@link GiniIndex}.
072     */
073    public CARTClassificationTrainer() {
074        this(Integer.MAX_VALUE);
075    }
076
077    /**
078     * Creates a CART trainer. Sets the impurity to the {@link GiniIndex}, uses
079     * all the features, and sets the minimum number of examples in a leaf to {@link #MIN_EXAMPLES}.
080     * @param maxDepth The maximum depth of the tree.
081     */
082    public CARTClassificationTrainer(int maxDepth) {
083        this(maxDepth, MIN_EXAMPLES, 1.0f, new GiniIndex(), Trainer.DEFAULT_SEED);
084    }
085
086    /**
087     * Creates a CART Trainer. Sets the impurity to the {@link GiniIndex}.
088     * @param maxDepth The maximum depth of the tree.
089     * @param fractionFeaturesInSplit The fraction of features available in each split.
090     * @param seed The seed for the RNG.
091     */
092    public CARTClassificationTrainer(int maxDepth, float fractionFeaturesInSplit, long seed) {
093        this(maxDepth, MIN_EXAMPLES, fractionFeaturesInSplit, new GiniIndex(), seed);
094    }
095
096    @Override
097    protected AbstractTrainingNode<Label> mkTrainingNode(Dataset<Label> examples) {
098        return new ClassifierTrainingNode(impurity, examples);
099    }
100
101    @Override
102    public String toString() {
103        StringBuilder buffer = new StringBuilder();
104
105        buffer.append("CARTClassificationTrainer(maxDepth=");
106        buffer.append(maxDepth);
107        buffer.append(",minChildWeight=");
108        buffer.append(minChildWeight);
109        buffer.append(",fractionFeaturesInSplit=");
110        buffer.append(fractionFeaturesInSplit);
111        buffer.append(",impurity=");
112        buffer.append(impurity.toString());
113        buffer.append(",seed=");
114        buffer.append(seed);
115        buffer.append(")");
116
117        return buffer.toString();
118    }
119
120    @Override
121    public TrainerProvenance getProvenance() {
122        return new TrainerProvenanceImpl(this);
123    }
124}