/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.xgboost;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import org.tribuo.Trainer;
import org.tribuo.regression.xgboost.XGBoostRegressionTrainer.RegressionType;

/**
 * CLI options for configuring an XGBoost regression trainer.
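 * <p>
 * A hypothetical command line fragment using the flags declared below (the
 * flag values are illustrative only, not recommended settings):
 * <pre>
 * --xgb-ensemble-size 50 --xgb-eta 0.1 --xgb-max-depth 5 --xgb-subsample 0.8
 * </pre>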
 */
public class XGBoostOptions implements Options {
    @Option(longName="xgb-regression-metric",usage="Regression type to use. Defaults to LINEAR.")
    public RegressionType rType = RegressionType.LINEAR;
    @Option(longName="xgb-ensemble-size",usage="Number of trees in the ensemble (required, must be positive).")
    public int ensembleSize = -1;
    @Option(longName="xgb-alpha",usage="L1 regularization term for weights (default 0).")
    public float alpha = 0.0f;
    @Option(longName="xgb-min-weight",usage="Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).")
    public float minWeight = 1.0f;
    @Option(longName="xgb-max-depth",usage="Max tree depth (default 6, range (0,inf]).")
    public int depth = 6;
    @Option(longName="xgb-eta",usage="Step size shrinkage parameter (default 0.3, range [0,1]).")
    public float eta = 0.3f;
    @Option(longName="xgb-subsample-features",usage="Subsample features for each tree (default 1, range (0,1]).")
    public float subsampleFeatures = 1.0f;
    @Option(longName="xgb-gamma",usage="Minimum loss reduction to make a split (default 0, range [0,inf]).")
    public float gamma = 0.0f;
    @Option(longName="xgb-lambda",usage="L2 regularization term for weights (default 1).")
    public float lambda = 1.0f;
    @Option(longName="xgb-quiet",usage="Make the XGBoost training procedure quiet.")
    public boolean quiet;
    @Option(longName="xgb-subsample",usage="Subsample size for each tree (default 1, range (0,1]).")
    public float subsample = 1.0f;
    @Option(longName="xgb-num-threads",usage="Number of threads to use (default 4, range [1, num hw threads]).")
    public int numThreads = 4;
    @Option(longName="xgb-seed",usage="Sets the random seed for XGBoost.")
    public long seed = Trainer.DEFAULT_SEED;

    /**
     * Gets the configured XGBoostRegressionTrainer.
     * @return The configured trainer.
     */
    public XGBoostRegressionTrainer getTrainer() {
        return new XGBoostRegressionTrainer(rType, ensembleSize, eta, gamma, depth, minWeight, subsample, subsampleFeatures, lambda, alpha, numThreads, quiet, seed);
    }
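
    /*
     * A minimal usage sketch, assuming the command line is parsed with OLCUT's
     * ConfigurationManager (the dataset variable trainingDataset is a
     * hypothetical stand-in for a Dataset<Regressor>):
     *
     *   XGBoostOptions opts = new XGBoostOptions();
     *   ConfigurationManager cm = new ConfigurationManager(args, opts);
     *   XGBoostRegressionTrainer trainer = opts.getTrainer();
     *   Model<Regressor> model = trainer.train(trainingDataset);
     */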
}