001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.ensemble;
018
019import com.oracle.labs.mlrg.olcut.config.Option;
020import com.oracle.labs.mlrg.olcut.config.Options;
021import org.tribuo.Trainer;
022import org.tribuo.classification.Label;
023import org.tribuo.common.tree.DecisionTreeTrainer;
024import org.tribuo.common.tree.RandomForestTrainer;
025import org.tribuo.ensemble.BaggingTrainer;
026
027import java.util.logging.Logger;
028
029/**
030 * Options for building a classification ensemble.
031 */
032public class ClassificationEnsembleOptions implements Options {
033    private static final Logger logger = Logger.getLogger(ClassificationEnsembleOptions.class.getName());
034    public enum EnsembleType { ADABOOST, BAGGING, RF }
035
036    @Option(longName="ensemble-type",usage="Ensemble method, options are {ADABOOST, BAGGING, RF}.")
037    public EnsembleType type = EnsembleType.BAGGING;
038    @Option(longName="ensemble-size",usage="Number of base learners in the ensemble.")
039    public int ensembleSize = -1;
040    @Option(longName="ensemble-seed",usage="RNG seed.")
041    public long seed = Trainer.DEFAULT_SEED;
042
043
044    public Trainer<Label> wrapTrainer(Trainer<Label> trainer) {
045        if ((ensembleSize > 0) && (type != null)) {
046            switch (type) {
047                case ADABOOST:
048                    logger.info("Using Adaboost with " + ensembleSize + " members.");
049                    return new AdaBoostTrainer(trainer,ensembleSize,seed);
050                case BAGGING:
051                    logger.info("Using Bagging with " + ensembleSize + " members.");
052                    return new BaggingTrainer<>(trainer,new VotingCombiner(),ensembleSize,seed);
053                case RF:
054                    if (trainer instanceof DecisionTreeTrainer) {
055                        logger.info("Using Random Forests with " + ensembleSize + " members.");
056                        return new RandomForestTrainer<>((DecisionTreeTrainer<Label>)trainer,new VotingCombiner(),ensembleSize,seed);
057                    } else {
058                        throw new IllegalArgumentException("RandomForestTrainer requires a DecisionTreeTrainer");
059                    }
060                default:
061                    throw new IllegalArgumentException("Unknown ensemble type :" + type);
062            }
063        } else {
064            return trainer;
065        }
066    }
067}