/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.xgboost;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import org.tribuo.Trainer;
import org.tribuo.regression.xgboost.XGBoostRegressionTrainer.RegressionType;

/**
 * CLI options for configuring an XGBoost regression trainer.
 */
public class XGBoostOptions implements Options {
    @Option(longName = "xgb-regression-metric", usage = "Regression type to use. Defaults to LINEAR.")
    public RegressionType rType = RegressionType.LINEAR;
    @Option(longName = "xgb-ensemble-size", usage = "Number of trees in the ensemble.")
    public int ensembleSize = -1;
    @Option(longName = "xgb-alpha", usage = "L1 regularization term for weights (default 0).")
    public float alpha = 0.0f;
    @Option(longName = "xgb-min-weight", usage = "Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).")
    public float minWeight = 1;
    @Option(longName = "xgb-max-depth", usage = "Max tree depth (default 6, range (0,inf]).")
    public int depth = 6;
    @Option(longName = "xgb-eta", usage = "Step size shrinkage parameter (default 0.3, range [0,1]).")
    public float eta = 0.3f;
    @Option(longName = "xgb-subsample-features", usage = "Subsample features for each tree (default 1, range (0,1]).")
    public float subsampleFeatures = 1.0f;
    @Option(longName = "xgb-gamma", usage = "Minimum loss reduction to make a split (default 0, range [0,inf]).")
    public float gamma = 0.0f;
    @Option(longName = "xgb-lambda", usage = "L2 regularization term for weights (default 1).")
    public float lambda = 1.0f;
    @Option(longName = "xgb-quiet", usage = "Make the XGBoost training procedure quiet.")
    public boolean quiet;
    @Option(longName = "xgb-subsample", usage = "Subsample size for each tree (default 1, range (0,1]).")
    public float subsample = 1.0f;
    @Option(longName = "xgb-num-threads", usage = "Number of threads to use (default 4, range [1, num hw threads]).")
    public int numThreads = 4;
    @Option(longName = "xgb-seed", usage = "Sets the random seed for XGBoost.")
    public long seed = Trainer.DEFAULT_SEED;

    /**
     * Gets the configured XGBoostRegressionTrainer.
     * @return The configured trainer.
     */
    public XGBoostRegressionTrainer getTrainer() {
        return new XGBoostRegressionTrainer(rType, ensembleSize, eta, gamma, depth, minWeight, subsample, subsampleFeatures, lambda, alpha, numThreads, quiet, seed);
    }
}
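
/*
 * Usage sketch (an editor's illustration, not part of the original file): shows how
 * the annotated @Option fields above are typically populated from command-line flags
 * and then turned into a trainer via getTrainer(). It assumes OLCUT's
 * ConfigurationManager(String[], Options) constructor, which parses flags such as
 * --xgb-ensemble-size against the annotated fields; the class name and example
 * argument values below are hypothetical.
 */
class XGBoostOptionsUsageExample {
    public static void main(String[] args) {
        // e.g. args = {"--xgb-ensemble-size", "50", "--xgb-eta", "0.1", "--xgb-max-depth", "4"}
        XGBoostOptions opts = new XGBoostOptions();
        // Parsing mutates the @Option-annotated fields of opts in place.
        com.oracle.labs.mlrg.olcut.config.ConfigurationManager cm =
                new com.oracle.labs.mlrg.olcut.config.ConfigurationManager(args, opts);
        // Build the trainer from the parsed options; elsewhere in a Tribuo pipeline
        // it would then be used as trainer.train(regressionDataset).
        XGBoostRegressionTrainer trainer = opts.getTrainer();
        System.out.println("Configured trainer: " + trainer);
    }
}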