001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.tree; 018 019import com.oracle.labs.mlrg.olcut.config.PropertyException; 020import org.tribuo.Output; 021import org.tribuo.ensemble.BaggingTrainer; 022import org.tribuo.ensemble.EnsembleCombiner; 023 024import java.util.logging.Logger; 025 026/** 027 * A trainer which produces a random forest. 028 * <p> 029 * Random Forests are basically bagged trees, with feature subsampling at each of the nodes. 030 * It's up to the user to supply a decision tree trainer which has feature subsampling turned on by 031 * checking {@link DecisionTreeTrainer#getFractionFeaturesInSplit()}. 032 * <p> 033 * See: 034 * <pre> 035 * J. Friedman, T. Hastie, & R. Tibshirani. 036 * "The Elements of Statistical Learning" 037 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a> 038 * </pre> 039 */ 040public class RandomForestTrainer<T extends Output<T>> extends BaggingTrainer<T> { 041 042 private static final Logger logger = Logger.getLogger(RandomForestTrainer.class.getName()); 043 044 /** 045 * For the configuration system. 046 */ 047 private RandomForestTrainer() { } 048 049 /** 050 * Constructs a RandomForestTrainer with the default seed {@link org.tribuo.Trainer#DEFAULT_SEED}. 051 * <p> 052 * Throws {@link PropertyException} if the trainer is not set to subsample the features. 053 * @param trainer The tree trainer. 054 * @param combiner The combining function for the ensemble. 055 * @param numMembers The number of ensemble members to train. 056 */ 057 public RandomForestTrainer(DecisionTreeTrainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers) { 058 super(trainer,combiner,numMembers); 059 } 060 061 /** 062 * Constructs a RandomForestTrainer with the supplied seed, trainer, combining function and number of members. 063 * <p> 064 * Throws {@link PropertyException} if the trainer is not set to subsample the features. 065 * @param trainer The tree trainer. 066 * @param combiner The combining function for the ensemble. 067 * @param numMembers The number of ensemble members to train. 068 * @param seed The RNG seed. 069 */ 070 public RandomForestTrainer(DecisionTreeTrainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed) { 071 super(trainer,combiner,numMembers,seed); 072 } 073 074 /** 075 * Used by the OLCUT configuration system, and should not be called by external code. 076 */ 077 @Override 078 public void postConfig() { 079 super.postConfig(); 080 if (!(innerTrainer instanceof DecisionTreeTrainer)) { 081 throw new PropertyException("","innerTrainer","RandomForestTrainer requires a decision tree innerTrainer"); 082 } 083 DecisionTreeTrainer<T> t = (DecisionTreeTrainer<T>) innerTrainer; 084 if (t.getFractionFeaturesInSplit() == 1f) { 085 throw new PropertyException("","innerTrainer","RandomForestTrainer requires that the decision tree " + 086 "innerTrainer have fractional features in split."); 087 } 088 } 089 090 @Override 091 protected String ensembleName() { 092 return "random-forest-ensemble"; 093 } 094 095 @Override 096 public String toString() { 097 StringBuilder buffer = new StringBuilder(); 098 099 buffer.append("RandomForestTrainer("); 100 buffer.append("innerTrainer="); 101 buffer.append(innerTrainer.toString()); 102 buffer.append(",combiner="); 103 buffer.append(combiner.toString()); 104 buffer.append(",numMembers="); 105 buffer.append(numMembers); 106 buffer.append(",seed="); 107 buffer.append(seed); 108 buffer.append(")"); 109 110 return buffer.toString(); 111 } 112 113}