001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.tree;
018
019import com.oracle.labs.mlrg.olcut.config.PropertyException;
020import org.tribuo.Output;
021import org.tribuo.ensemble.BaggingTrainer;
022import org.tribuo.ensemble.EnsembleCombiner;
023
024import java.util.logging.Logger;
025
026/**
027 * A trainer which produces a random forest.
028 * <p>
029 * Random Forests are basically bagged trees, with feature subsampling at each of the nodes.
030 * It's up to the user to supply a decision tree trainer which has feature subsampling turned on by
031 * checking {@link DecisionTreeTrainer#getFractionFeaturesInSplit()}.
032 * <p>
033 * See:
034 * <pre>
035 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
036 * "The Elements of Statistical Learning"
037 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
038 * </pre>
039 */
040public class RandomForestTrainer<T extends Output<T>> extends BaggingTrainer<T> {
041
042    private static final Logger logger = Logger.getLogger(RandomForestTrainer.class.getName());
043
044    /**
045     * For the configuration system.
046     */
047    private RandomForestTrainer() { }
048
049    /**
050     * Constructs a RandomForestTrainer with the default seed {@link org.tribuo.Trainer#DEFAULT_SEED}.
051     * <p>
052     * Throws {@link PropertyException} if the trainer is not set to subsample the features.
053     * @param trainer The tree trainer.
054     * @param combiner The combining function for the ensemble.
055     * @param numMembers The number of ensemble members to train.
056     */
057    public RandomForestTrainer(DecisionTreeTrainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers) {
058        super(trainer,combiner,numMembers);
059    }
060
061    /**
062     * Constructs a RandomForestTrainer with the supplied seed, trainer, combining function and number of members.
063     * <p>
064     * Throws {@link PropertyException} if the trainer is not set to subsample the features.
065     * @param trainer The tree trainer.
066     * @param combiner The combining function for the ensemble.
067     * @param numMembers The number of ensemble members to train.
068     * @param seed The RNG seed.
069     */
070    public RandomForestTrainer(DecisionTreeTrainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed) {
071        super(trainer,combiner,numMembers,seed);
072    }
073
074    /**
075     * Used by the OLCUT configuration system, and should not be called by external code.
076     */
077    @Override
078    public void postConfig() {
079        super.postConfig();
080        if (!(innerTrainer instanceof DecisionTreeTrainer)) {
081            throw new PropertyException("","innerTrainer","RandomForestTrainer requires a decision tree innerTrainer");
082        }
083        DecisionTreeTrainer<T> t = (DecisionTreeTrainer<T>) innerTrainer;
084        if (t.getFractionFeaturesInSplit() == 1f) {
085            throw new PropertyException("","innerTrainer","RandomForestTrainer requires that the decision tree " +
086                    "innerTrainer have fractional features in split.");
087        }
088    }
089
090    @Override
091    protected String ensembleName() {
092        return "random-forest-ensemble";
093    }
094
095    @Override
096    public String toString() {
097        StringBuilder buffer = new StringBuilder();
098
099        buffer.append("RandomForestTrainer(");
100        buffer.append("innerTrainer=");
101        buffer.append(innerTrainer.toString());
102        buffer.append(",combiner=");
103        buffer.append(combiner.toString());
104        buffer.append(",numMembers=");
105        buffer.append(numMembers);
106        buffer.append(",seed=");
107        buffer.append(seed);
108        buffer.append(")");
109
110        return buffer.toString();
111    }
112    
113}