001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.experiments;
018
019import com.oracle.labs.mlrg.olcut.config.Option;
020import org.tribuo.Trainer;
021import org.tribuo.classification.ClassificationOptions;
022import org.tribuo.classification.Label;
023import org.tribuo.classification.dtree.CARTClassificationOptions;
024import org.tribuo.classification.ensemble.ClassificationEnsembleOptions;
025import org.tribuo.classification.liblinear.LibLinearOptions;
026import org.tribuo.classification.libsvm.LibSVMOptions;
027import org.tribuo.classification.mnb.MultinomialNaiveBayesOptions;
028import org.tribuo.classification.sgd.kernel.KernelSVMOptions;
029import org.tribuo.classification.sgd.linear.LinearSGDOptions;
030import org.tribuo.classification.xgboost.XGBoostOptions;
031import org.tribuo.common.nearest.KNNClassifierOptions;
032import org.tribuo.hash.HashingOptions;
033import org.tribuo.hash.HashingOptions.ModelHashingType;
034
035import java.util.logging.Logger;
036
037/**
038 * Aggregates all the classification algorithms.
039 */
040public class AllTrainerOptions implements ClassificationOptions<Trainer<Label>> {
041    private static final Logger logger = Logger.getLogger(AllTrainerOptions.class.getName());
042
043    public enum AlgorithmType {
044        CART, KNN, LIBLINEAR, LIBSVM, MNB, SGD_KERNEL, SGD_LINEAR, XGBOOST //RANDOM_FOREST,
045    }
046
047    @Option(longName = "algorithm", usage = "Type of learner (or base learner). Defaults to SGD_LINEAR.")
048    public AlgorithmType algorithm = AlgorithmType.SGD_LINEAR;
049
050    public CARTClassificationOptions cartOptions;
051    public KNNClassifierOptions knnOptions;
052    public LibLinearOptions liblinearOptions;
053    public LibSVMOptions libsvmOptions;
054    public MultinomialNaiveBayesOptions mnbOptions;
055    public KernelSVMOptions kernelSVMOptions;
056    public LinearSGDOptions linearSGDOptions;
057    public XGBoostOptions xgBoostOptions;
058
059    public ClassificationEnsembleOptions ensemble;
060    public HashingOptions hashingOptions;
061
062    @Override
063    public Trainer<Label> getTrainer() {
064        Trainer<Label> trainer;
065        logger.info("Using " + algorithm);
066        switch (algorithm) {
067            case CART:
068                trainer = cartOptions.getTrainer();
069                break;
070            case KNN:
071                trainer = knnOptions.getTrainer();
072                break;
073            case LIBLINEAR:
074                trainer = liblinearOptions.getTrainer();
075                break;
076            case LIBSVM:
077                trainer = libsvmOptions.getTrainer();
078                break;
079            case MNB:
080                trainer = mnbOptions.getTrainer();
081                break;
082            case SGD_KERNEL:
083                trainer = kernelSVMOptions.getTrainer();
084                break;
085            case SGD_LINEAR:
086                trainer = linearSGDOptions.getTrainer();
087                break;
088            case XGBOOST:
089                trainer = xgBoostOptions.getTrainer();
090                break;
091            default:
092                throw new IllegalArgumentException("Unknown classifier " + algorithm);
093        }
094
095        if ((ensemble.ensembleSize > 0) && (ensemble.type != null)) {
096            switch (algorithm) {
097                case XGBOOST:
098                    throw new IllegalArgumentException(
099                            "Not allowed to ensemble XGBoost models. Why ensemble an ensemble?");
100                default:
101                    trainer = ensemble.wrapTrainer(trainer);
102                    break;
103            }
104        }
105
106        if (hashingOptions.modelHashingAlgorithm != ModelHashingType.NONE) {
107            trainer = hashingOptions.getHashedTrainer(trainer);
108        }
109        logger.info("Trainer description " + trainer.toString());
110        return trainer;
111    }
112
113}