001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.xgboost;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Dataset;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Trainer;
025import org.tribuo.classification.Label;
026import org.tribuo.common.xgboost.XGBoostModel;
027import org.tribuo.common.xgboost.XGBoostTrainer;
028import org.tribuo.provenance.ModelProvenance;
029import org.tribuo.provenance.TrainerProvenance;
030import org.tribuo.provenance.impl.TrainerProvenanceImpl;
031import ml.dmlc.xgboost4j.java.Booster;
032import ml.dmlc.xgboost4j.java.XGBoost;
033import ml.dmlc.xgboost4j.java.XGBoostError;
034
035import java.time.OffsetDateTime;
036import java.util.Collections;
037import java.util.Map;
038import java.util.function.Function;
039import java.util.logging.Level;
040import java.util.logging.Logger;
041
042/**
043 * A {@link Trainer} which wraps the XGBoost training procedure.
044 * <p>
045 * This only exposes a few of XGBoost's training parameters.
046 * <p>
047 * It uses pthreads outside of the JVM to parallelise the computation.
048 * <p>
049 * See:
050 * <pre>
051 * Chen T, Guestrin C.
052 * "XGBoost: A Scalable Tree Boosting System"
053 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
054 * </pre>
055 * and for the original algorithm:
056 * <pre>
057 * Friedman JH.
058 * "Greedy Function Approximation: a Gradient Boosting Machine"
059 * Annals of statistics, 2001.
060 * </pre>
061 * <p>
062 * Note: XGBoost requires a native library, on macOS this library requires libomp (which can be installed via homebrew),
063 * on Windows this native library must be compiled into a jar as it's not contained in the official XGBoost binary
064 * on Maven Central.
065 */
066public final class XGBoostClassificationTrainer extends XGBoostTrainer<Label> {
067
068    private static final Logger logger = Logger.getLogger(XGBoostClassificationTrainer.class.getName());
069
070    @Config(description="Evaluation metric to use. The default value is set based on the objective function, so this can be usually left blank.")
071    private String evalMetric = "";
072
073    public XGBoostClassificationTrainer(int numTrees) {
074        this(numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, 4, true, Trainer.DEFAULT_SEED);
075    }
076
077    public XGBoostClassificationTrainer(int numTrees, int numThreads, boolean silent) {
078        this(numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, numThreads, silent, Trainer.DEFAULT_SEED);
079    }
080
081    /**
082     * Create an XGBoost trainer.
083     *
084     * @param numTrees Number of trees to boost.
085     * @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
086     * @param gamma Minimum loss reduction to make a split (default 0, range
087     * [0,inf]).
088     * @param maxDepth Maximum tree depth (default 6, range [1,inf]).
089     * @param minChildWeight Minimum sum of instance weights needed in a leaf
090     * (default 1, range [0, inf]).
091     * @param subsample Subsample size for each tree (default 1, range (0,1]).
092     * @param featureSubsample Subsample features for each tree (default 1,
093     * range (0,1]).
094     * @param lambda L2 regularization term on weights (default 1).
095     * @param alpha L1 regularization term on weights (default 0).
096     * @param nThread Number of threads to use (default 4).
097     * @param silent Silence the training output text.
098     * @param seed RNG seed.
099     */
100    public XGBoostClassificationTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed) {
101        super(numTrees,eta,gamma,maxDepth,minChildWeight,subsample,featureSubsample,lambda,alpha,nThread,silent,seed);
102        postConfig();
103    }
104
105    /**
106     * This gives direct access to the XGBoost parameter map.
107     * <p>
108     * It lets you pick things that we haven't exposed like dropout trees, binary classification etc.
109     * <p>
110     * This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results.
111     * @param numTrees Number of trees to boost.
112     * @param parameters A map from string to object, where object can be Number or String.
113     */
114    public XGBoostClassificationTrainer(int numTrees, Map<String,Object> parameters) {
115        super(numTrees,parameters);
116        postConfig();
117    }
118
119    /**
120     * For olcut.
121     */
122    protected XGBoostClassificationTrainer() { }
123
124    /**
125     * Used by the OLCUT configuration system, and should not be called by external code.
126     */
127    @Override
128    public void postConfig() {
129        super.postConfig();
130        parameters.put("objective", "multi:softprob");
131        if(!evalMetric.isEmpty()) {
132            parameters.put("eval_metric", evalMetric);
133        }
134    }
135
136    @Override
137    public synchronized XGBoostModel<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
138        if (examples.getOutputInfo().getUnknownCount() > 0) {
139            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
140        }
141        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
142        ImmutableOutputInfo<Label> outputInfo = examples.getOutputIDInfo();
143        TrainerProvenance trainerProvenance = getProvenance();
144        trainInvocationCounter++;
145        parameters.put("num_class", outputInfo.size());
146        Booster model;
147        Function<Label,Float> responseExtractor = (Label l) -> (float) outputInfo.getID(l);
148        try {
149            DMatrixTuple<Label> trainingData = convertExamples(examples, featureMap, responseExtractor);
150            model = XGBoost.train(trainingData.data, parameters, numTrees, Collections.emptyMap(), null, null);
151        } catch (XGBoostError e) {
152            logger.log(Level.SEVERE, "XGBoost threw an error", e);
153            throw new IllegalStateException(e);
154        }
155
156        ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
157        XGBoostModel<Label> xgModel = createModel("xgboost-classification-model", provenance, featureMap, outputInfo, Collections.singletonList(model), new XGBoostClassificationConverter());
158
159        return xgModel;
160    }
161
162    @Override
163    public TrainerProvenance getProvenance() {
164        return new TrainerProvenanceImpl(this);
165    }
166}