001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.xgboost; 018 019import com.oracle.labs.mlrg.olcut.config.Option; 020import org.tribuo.Trainer; 021import org.tribuo.classification.ClassificationOptions; 022 023/** 024 * CLI options for training an XGBoost classifier. 025 */ 026public class XGBoostOptions implements ClassificationOptions<XGBoostClassificationTrainer> { 027 @Option(longName = "xgb-ensemble-size", usage = "Number of trees in the ensemble.") 028 public int xgbEnsembleSize = -1; 029 @Option(longName = "xgb-alpha", usage = "L1 regularization term for weights (default 0).") 030 public float xbgAlpha = 0.0f; 031 @Option(longName = "xgb-min-weight", usage = "Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).") 032 public float xgbMinWeight = 1; 033 @Option(longName = "xgb-max-depth", usage = "Max tree depth (default 6, range (0,inf]).") 034 public int xgbMaxDepth = 6; 035 @Option(longName = "xgb-eta", usage = "Step size shrinkage parameter (default 0.3, range [0,1]).") 036 public float xgbEta = 0.3f; 037 @Option(longName = "xgb-subsample-features", usage = "Subsample features for each tree (default 1, range (0,1]).") 038 public float xgbSubsampleFeatures; 039 @Option(longName = "xgb-gamma", usage = "Minimum loss reduction to make a split (default 0, range [0,inf]).") 040 public float xgbGamma = 0.0f; 041 @Option(longName = "xgb-lambda", usage = "L2 regularization term for weights (default 1).") 042 public float xgbLambda = 1.0f; 043 @Option(longName = "xgb-quiet", usage = "Make the XGBoost training procedure quiet.") 044 public boolean xgbQuiet; 045 @Option(longName = "xgb-subsample", usage = "Subsample size for each tree (default 1, range (0,1]).") 046 public float xgbSubsample = 1.0f; 047 @Option(longName = "xgb-num-threads", usage = "Number of threads to use (default 4, range (1, num hw threads)).") 048 public int xgbNumThreads; 049 @Option(longName = "xgb-seed", usage = "Sets the random seed for XGBoost.") 050 private long xgbSeed = Trainer.DEFAULT_SEED; 051 052 @Override 053 public XGBoostClassificationTrainer getTrainer() { 054 if (xgbEnsembleSize == -1) { 055 throw new IllegalArgumentException("Please supply the number of trees."); 056 } 057 return new XGBoostClassificationTrainer(xgbEnsembleSize, xgbEta, xgbGamma, xgbMaxDepth, xgbMinWeight, xgbSubsample, xgbSubsampleFeatures, xgbLambda, xbgAlpha, xgbNumThreads, xgbQuiet, xgbSeed); 058 } 059}