/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.xgboost;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.common.xgboost.XGBoostModel;
import org.tribuo.common.xgboost.XGBoostTrainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A {@link Trainer} which wraps the XGBoost training procedure.
 * This only exposes a few of XGBoost's training parameters.
 * It uses pthreads outside of the JVM to parallelise the computation.
 * <p>
 * Each output dimension is trained independently (and so contains a separate XGBoost ensemble).
 * <p>
 * See:
 * <pre>
 * Chen T, Guestrin C.
 * "XGBoost: A Scalable Tree Boosting System"
 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Friedman JH.
 * "Greedy Function Approximation: a Gradient Boosting Machine"
 * Annals of Statistics, 2001.
 * </pre>
 * <p>
 * Note: XGBoost requires a native library; on macOS this library requires libomp (which can be installed via Homebrew),
 * and on Windows this native library must be compiled into a jar as it's not contained in the official XGBoost binary
 * on Maven Central.
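 * <p>
 * For example, a minimal sketch of training (assuming {@code trainData} is a
 * {@code Dataset<Regressor>} constructed elsewhere):
 * <pre>{@code
 * XGBoostRegressionTrainer trainer = new XGBoostRegressionTrainer(100); // 100 boosting rounds
 * Model<Regressor> model = trainer.train(trainData);
 * }</pre>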
 */
public final class XGBoostRegressionTrainer extends XGBoostTrainer<Regressor> {

    private static final Logger logger = Logger.getLogger(XGBoostRegressionTrainer.class.getName());

    /**
     * Types of regression loss.
     */
    public enum RegressionType {
        /**
         * Squared error loss function.
         */
        LINEAR("reg:squarederror"),
        /**
         * Gamma loss function.
         */
        GAMMA("reg:gamma"),
        /**
         * Tweedie loss function.
         */
        TWEEDIE("reg:tweedie");

        /**
         * The XGBoost objective parameter value for this regression type.
         */
        public final String paramName;

        RegressionType(String paramName) {
            this.paramName = paramName;
        }
    }

    @Config(description="The type of regression.")
    private RegressionType rType = RegressionType.LINEAR;

    /**
     * Creates a trainer which uses squared error loss and the supplied number of trees, with the default parameters.
     * @param numTrees Number of trees to boost.
     */
    public XGBoostRegressionTrainer(int numTrees) {
        this(RegressionType.LINEAR, numTrees);
    }

    /**
     * Creates a trainer which uses the supplied loss and number of trees, with the default parameters.
     * @param rType The regression loss function.
     * @param numTrees Number of trees to boost.
     */
    public XGBoostRegressionTrainer(RegressionType rType, int numTrees) {
        this(rType, numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, 4, true, Trainer.DEFAULT_SEED);
    }

    /**
     * Creates a trainer which uses the supplied loss, number of trees, number of threads and verbosity.
     * @param rType The regression loss function.
     * @param numTrees Number of trees to boost.
     * @param numThreads Number of threads to use.
     * @param silent Silence the training output text.
     */
    public XGBoostRegressionTrainer(RegressionType rType, int numTrees, int numThreads, boolean silent) {
        this(rType, numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, numThreads, silent, Trainer.DEFAULT_SEED);
    }

    /**
     * Create an XGBoost trainer.
     *
     * @param rType The type of regression to build.
     * @param numTrees Number of trees to boost.
     * @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
     * @param gamma Minimum loss reduction to make a split (default 0, range [0,inf]).
     * @param maxDepth Maximum tree depth (default 6, range [1,inf]).
     * @param minChildWeight Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
     * @param subsample Subsample size for each tree (default 1, range (0,1]).
     * @param featureSubsample Subsample features for each tree (default 1, range (0,1]).
     * @param lambda L2 regularization term on weights (default 1).
     * @param alpha L1 regularization term on weights (default 0).
     * @param nThread Number of threads to use (default 4).
     * @param silent Silence the training output text.
     * @param seed RNG seed.
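     * <p>
     * For example, a hypothetical configuration with 500 trees, a smaller step size and deeper trees,
     * leaving the remaining arguments at their defaults:
     * <pre>{@code
     * XGBoostRegressionTrainer trainer = new XGBoostRegressionTrainer(
     *     RegressionType.LINEAR, 500, 0.1, 0.0, 8, 1.0, 1.0, 1.0, 1.0, 0.0, 4, true, Trainer.DEFAULT_SEED);
     * }</pre>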
     */
    public XGBoostRegressionTrainer(RegressionType rType, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed) {
        super(numTrees,eta,gamma,maxDepth,minChildWeight,subsample,featureSubsample,lambda,alpha,nThread,silent,seed);
        this.rType = rType;

        postConfig();
    }

    /**
     * This gives direct access to the XGBoost parameter map.
     * <p>
     * It lets you set parameters which Tribuo doesn't expose, such as dropout trees (DART).
     * <p>
     * This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results.
     * @param rType The type of the regression.
     * @param numTrees Number of trees to boost.
     * @param parameters A map from string to object, where object can be a Number or a String.
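     * <p>
     * For example, a minimal sketch enabling DART dropout boosting through the raw parameter map
     * ({@code "booster"} and {@code "eta"} are standard XGBoost parameter names):
     * <pre>{@code
     * Map<String,Object> params = new HashMap<>();
     * params.put("booster", "dart"); // use dropout boosted trees
     * params.put("eta", 0.1);
     * XGBoostRegressionTrainer trainer = new XGBoostRegressionTrainer(RegressionType.LINEAR, 50, params);
     * }</pre>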
     */
    public XGBoostRegressionTrainer(RegressionType rType, int numTrees, Map<String,Object> parameters) {
        super(numTrees,parameters);
        this.rType = rType;
        postConfig();
    }

    /**
     * For OLCUT.
     */
    private XGBoostRegressionTrainer() { }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        parameters.put("objective",rType.paramName);
    }

    @Override
    public synchronized XGBoostModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
        int numOutputs = outputInfo.size();
        TrainerProvenance trainerProvenance = getProvenance();
        trainInvocationCounter++;
        List<Booster> models = new ArrayList<>();
        try {
            // Use a null response extractor as we'll do the per-dimension regression extraction later.
            DMatrixTuple<Regressor> trainingData = convertExamples(examples, featureMap, null);

            // Extract the weights and the regression targets.
            float[][] outputs = new float[numOutputs][examples.size()];
            float[] weights = new float[examples.size()];
            int i = 0;
            for (Example<Regressor> e : examples) {
                weights[i] = e.getWeight();
                double[] curOutputs = e.getOutput().getValues();
                // Transpose them for easy training.
                for (int j = 0; j < numOutputs; j++) {
                    outputs[j][i] = (float) curOutputs[j];
                }
                i++;
            }
            trainingData.data.setWeight(weights);

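            // XGBoost's regression objectives are single-target, so a multidimensional regression
            // is fitted as one independent Booster per output dimension, reusing the same feature
            // DMatrix and swapping in each dimension's labels.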
            // Finished setup, now train one model per dimension.
            for (i = 0; i < numOutputs; i++) {
                trainingData.data.setLabel(outputs[i]);
                models.add(XGBoost.train(trainingData.data, parameters, numTrees, Collections.emptyMap(), null, null));
            }
        } catch (XGBoostError e) {
            logger.log(Level.SEVERE, "XGBoost threw an error", e);
            throw new IllegalStateException(e);
        }

        ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        XGBoostModel<Regressor> xgModel = createModel("xgboost-regression-model", provenance, featureMap, outputInfo, models, new XGBoostRegressionConverter());

        return xgModel;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}