/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.xgboost;

import com.oracle.labs.mlrg.olcut.config.Option;
import org.tribuo.Trainer;
import org.tribuo.classification.ClassificationOptions;

/**
 * CLI options for training an XGBoost classifier.
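 * <p>
 * A minimal usage sketch (the surrounding parsing code is illustrative, assuming the options are
 * populated by OLCUT's {@code com.oracle.labs.mlrg.olcut.config.ConfigurationManager}):
 * <pre>{@code
 * XGBoostOptions xgbOptions = new XGBoostOptions();
 * // Parse flags such as "--xgb-ensemble-size 50 --xgb-eta 0.1" into the annotated fields.
 * ConfigurationManager cm = new ConfigurationManager(args, xgbOptions);
 * XGBoostClassificationTrainer trainer = xgbOptions.getTrainer();
 * }</pre>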
 */
public class XGBoostOptions implements ClassificationOptions<XGBoostClassificationTrainer> {
    @Option(longName = "xgb-ensemble-size", usage = "Number of trees in the ensemble.")
    public int xgbEnsembleSize = -1;
    @Option(longName = "xgb-alpha", usage = "L1 regularization term for weights (default 0).")
    public float xgbAlpha = 0.0f;
    @Option(longName = "xgb-min-weight", usage = "Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).")
    public float xgbMinWeight = 1;
    @Option(longName = "xgb-max-depth", usage = "Max tree depth (default 6, range (0,inf]).")
    public int xgbMaxDepth = 6;
    @Option(longName = "xgb-eta", usage = "Step size shrinkage parameter (default 0.3, range [0,1]).")
    public float xgbEta = 0.3f;
    @Option(longName = "xgb-subsample-features", usage = "Subsample features for each tree (default 1, range (0,1]).")
    public float xgbSubsampleFeatures = 1.0f;
    @Option(longName = "xgb-gamma", usage = "Minimum loss reduction to make a split (default 0, range [0,inf]).")
    public float xgbGamma = 0.0f;
    @Option(longName = "xgb-lambda", usage = "L2 regularization term for weights (default 1).")
    public float xgbLambda = 1.0f;
    @Option(longName = "xgb-quiet", usage = "Make the XGBoost training procedure quiet.")
    public boolean xgbQuiet;
    @Option(longName = "xgb-subsample", usage = "Subsample size for each tree (default 1, range (0,1]).")
    public float xgbSubsample = 1.0f;
    @Option(longName = "xgb-num-threads", usage = "Number of threads to use (default 4, range [1, num hw threads]).")
    public int xgbNumThreads = 4;
    @Option(longName = "xgb-seed", usage = "Sets the random seed for XGBoost.")
    private long xgbSeed = Trainer.DEFAULT_SEED;

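    /**
     * Constructs an {@link XGBoostClassificationTrainer} configured with these options.
     * @return The configured trainer.
     * @throws IllegalArgumentException If the ensemble size has not been set.
     */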
    @Override
    public XGBoostClassificationTrainer getTrainer() {
        if (xgbEnsembleSize == -1) {
            throw new IllegalArgumentException("Please supply the number of trees via --xgb-ensemble-size.");
        }
        return new XGBoostClassificationTrainer(xgbEnsembleSize, xgbEta, xgbGamma, xgbMaxDepth, xgbMinWeight, xgbSubsample, xgbSubsampleFeatures, xgbLambda, xgbAlpha, xgbNumThreads, xgbQuiet, xgbSeed);
    }
}