/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.ensemble;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.common.tree.DecisionTreeTrainer;
import org.tribuo.common.tree.RandomForestTrainer;
import org.tribuo.ensemble.BaggingTrainer;

import java.util.logging.Logger;

/**
 * Options for building a classification ensemble.
031 */ 032public class ClassificationEnsembleOptions implements Options { 033 private static final Logger logger = Logger.getLogger(ClassificationEnsembleOptions.class.getName()); 034 public enum EnsembleType { ADABOOST, BAGGING, RF } 035 036 @Option(longName="ensemble-type",usage="Ensemble method, options are {ADABOOST, BAGGING, RF}.") 037 public EnsembleType type = EnsembleType.BAGGING; 038 @Option(longName="ensemble-size",usage="Number of base learners in the ensemble.") 039 public int ensembleSize = -1; 040 @Option(longName="ensemble-seed",usage="RNG seed.") 041 public long seed = Trainer.DEFAULT_SEED; 042 043 044 public Trainer<Label> wrapTrainer(Trainer<Label> trainer) { 045 if ((ensembleSize > 0) && (type != null)) { 046 switch (type) { 047 case ADABOOST: 048 logger.info("Using Adaboost with " + ensembleSize + " members."); 049 return new AdaBoostTrainer(trainer,ensembleSize,seed); 050 case BAGGING: 051 logger.info("Using Bagging with " + ensembleSize + " members."); 052 return new BaggingTrainer<>(trainer,new VotingCombiner(),ensembleSize,seed); 053 case RF: 054 if (trainer instanceof DecisionTreeTrainer) { 055 logger.info("Using Random Forests with " + ensembleSize + " members."); 056 return new RandomForestTrainer<>((DecisionTreeTrainer<Label>)trainer,new VotingCombiner(),ensembleSize,seed); 057 } else { 058 throw new IllegalArgumentException("RandomForestTrainer requires a DecisionTreeTrainer"); 059 } 060 default: 061 throw new IllegalArgumentException("Unknown ensemble type :" + type); 062 } 063 } else { 064 return trainer; 065 } 066 } 067}