/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.experiments;

import com.oracle.labs.mlrg.olcut.config.Option;
import org.tribuo.Trainer;
import org.tribuo.classification.ClassificationOptions;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.CARTClassificationOptions;
import org.tribuo.classification.ensemble.ClassificationEnsembleOptions;
import org.tribuo.classification.liblinear.LibLinearOptions;
import org.tribuo.classification.libsvm.LibSVMOptions;
import org.tribuo.classification.mnb.MultinomialNaiveBayesOptions;
import org.tribuo.classification.sgd.kernel.KernelSVMOptions;
import org.tribuo.classification.sgd.linear.LinearSGDOptions;
import org.tribuo.classification.xgboost.XGBoostOptions;
import org.tribuo.common.nearest.KNNClassifierOptions;
import org.tribuo.hash.HashingOptions;
import org.tribuo.hash.HashingOptions.ModelHashingType;

import java.util.logging.Logger;

/**
 * Aggregates all the classification algorithms.
 * <p>
 * One options object is held per supported algorithm. {@link #getTrainer()} asks the
 * one matching {@link #algorithm} for a trainer, then optionally wraps that trainer in
 * an ensemble and in a hashing trainer.
 * <p>
 * NOTE(review): none of the nested option fields are instantiated in this class; they
 * are presumably populated by the options processor before {@link #getTrainer()} is
 * called - confirm against the caller.
 */
public class AllTrainerOptions implements ClassificationOptions<Trainer<Label>> {
    private static final Logger logger = Logger.getLogger(AllTrainerOptions.class.getName());

    /**
     * The base learners which can be selected via {@link #algorithm}.
     */
    public enum AlgorithmType {
        CART,
        KNN,
        LIBLINEAR,
        LIBSVM,
        MNB,
        SGD_KERNEL,
        SGD_LINEAR,
        XGBOOST
        // RANDOM_FOREST is not currently an option.
    }

    /**
     * Which learner (or base learner, when ensembling) to build.
     */
    @Option(longName = "algorithm", usage = "Type of learner (or base learner). Defaults to SGD_LINEAR.")
    public AlgorithmType algorithm = AlgorithmType.SGD_LINEAR;

    /** Consulted when {@link #algorithm} is {@link AlgorithmType#CART}. */
    public CARTClassificationOptions cartOptions;
    /** Consulted when {@link #algorithm} is {@link AlgorithmType#KNN}. */
    public KNNClassifierOptions knnOptions;
    /** Consulted when {@link #algorithm} is {@link AlgorithmType#LIBLINEAR}. */
    public LibLinearOptions liblinearOptions;
    /** Consulted when {@link #algorithm} is {@link AlgorithmType#LIBSVM}. */
    public LibSVMOptions libsvmOptions;
    /** Consulted when {@link #algorithm} is {@link AlgorithmType#MNB}. */
    public MultinomialNaiveBayesOptions mnbOptions;
    /** Consulted when {@link #algorithm} is {@link AlgorithmType#SGD_KERNEL}. */
    public KernelSVMOptions kernelSVMOptions;
    /** Consulted when {@link #algorithm} is {@link AlgorithmType#SGD_LINEAR}. */
    public LinearSGDOptions linearSGDOptions;
    /** Consulted when {@link #algorithm} is {@link AlgorithmType#XGBOOST}. */
    public XGBoostOptions xgBoostOptions;

    /** Ensemble settings; the base trainer is wrapped when a positive size and a type are set. */
    public ClassificationEnsembleOptions ensemble;
    /** Hashing settings; the trainer is wrapped unless the hashing type is {@code NONE}. */
    public HashingOptions hashingOptions;

    /**
     * Builds the trainer described by these options.
     * <p>
     * The steps run in this order: the selected algorithm is logged, the base trainer is
     * fetched from the matching options object, the ensemble wrapper is applied if
     * requested, the hashing wrapper is applied if requested, and the final trainer's
     * description is logged.
     *
     * @return The configured trainer.
     * @throws IllegalArgumentException If {@link #algorithm} is not a known type, or if
     *                                  ensembling is requested together with XGBoost.
     */
    @Override
    public Trainer<Label> getTrainer() {
        logger.info("Using " + algorithm);
        Trainer<Label> result = baseTrainer();

        boolean ensembleRequested = (ensemble.ensembleSize > 0) && (ensemble.type != null);
        if (ensembleRequested) {
            // The XGBoost check happens only after the base trainer has been built.
            if (algorithm == AlgorithmType.XGBOOST) {
                throw new IllegalArgumentException(
                        "Not allowed to ensemble XGBoost models. Why ensemble an ensemble?");
            }
            result = ensemble.wrapTrainer(result);
        }

        boolean hashingRequested = hashingOptions.modelHashingAlgorithm != ModelHashingType.NONE;
        if (hashingRequested) {
            result = hashingOptions.getHashedTrainer(result);
        }

        logger.info("Trainer description " + result.toString());
        return result;
    }

    /**
     * Fetches the un-wrapped trainer from the options object matching {@link #algorithm}.
     *
     * @return The base trainer for the selected algorithm.
     * @throws IllegalArgumentException If {@link #algorithm} has no matching options object.
     */
    private Trainer<Label> baseTrainer() {
        switch (algorithm) {
            case CART:
                return cartOptions.getTrainer();
            case KNN:
                return knnOptions.getTrainer();
            case LIBLINEAR:
                return liblinearOptions.getTrainer();
            case LIBSVM:
                return libsvmOptions.getTrainer();
            case MNB:
                return mnbOptions.getTrainer();
            case SGD_KERNEL:
                return kernelSVMOptions.getTrainer();
            case SGD_LINEAR:
                return linearSGDOptions.getTrainer();
            case XGBOOST:
                return xgBoostOptions.getTrainer();
            default:
                throw new IllegalArgumentException("Unknown classifier " + algorithm);
        }
    }

}