/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.common.xgboost;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.util.Util;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.logging.Logger;

/**
 * A {@link Trainer} which wraps the XGBoost training procedure.
 * <p>
 * This only exposes a few of XGBoost's training parameters.
 * <p>
 * It uses pthreads outside of the JVM to parallelise the computation.
 * <p>
 * See:
 * <pre>
 * Chen T, Guestrin C.
 * "XGBoost: A Scalable Tree Boosting System"
 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Friedman JH.
 * "Greedy Function Approximation: a Gradient Boosting Machine"
 * Annals of Statistics, 2001.
 * </pre>
 * N.B.: This uses a native C++ implementation of xgboost that links to various C libraries, including libgomp
 * and glibc. If you're running on Alpine, which does not natively use glibc, you'll need to install glibc
 * into the container. On Windows this binary is not available in the Maven Central release; you'll need
 * to compile it from source.
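 * <p>
 * A minimal usage sketch (this assumes the {@code XGBoostClassificationTrainer} subclass
 * from the {@code tribuo-classification-xgboost} module is on the classpath, and
 * {@code trainDataset} is a hypothetical {@code Dataset<Label>}):
 * <pre>{@code
 *    // Train a gradient boosted tree ensemble with 50 trees and default parameters.
 *    XGBoostClassificationTrainer trainer = new XGBoostClassificationTrainer(50);
 *    Model<Label> model = trainer.train(trainDataset);
 * }</pre>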
 */
public abstract class XGBoostTrainer<T extends Output<T>> implements Trainer<T>, WeightedExamples {
    /* Alpine install command
     * <pre>
     *    $ apk --no-cache add ca-certificates wget
     *    $ wget -q -O /etc/apk/keys/sgerrand.rsa.pub https://alpine-pkgs.sgerrand.com/sgerrand.rsa.pub
     *    $ wget https://github.com/sgerrand/alpine-pkg-glibc/releases/download/2.30-r0/glibc-2.30-r0.apk
     *    $ apk add glibc-2.30-r0.apk
     * </pre>
     */

    private static final Logger logger = Logger.getLogger(XGBoostTrainer.class.getName());

    protected final Map<String, Object> parameters = new HashMap<>();

    /**
     * The type of XGBoost model.
     */
    public enum BoosterType {
        /**
         * A boosted linear model.
         */
        LINEAR("gblinear"),
        /**
         * A gradient boosted decision tree.
         */
        GBTREE("gbtree"),
        /**
         * A gradient boosted decision tree using dropout.
         */
        DART("dart");

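        /** The booster parameter value used by the XGBoost native library. */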
        public final String paramName;

        BoosterType(String paramName) {
            this.paramName = paramName;
        }
    }

    @Config(mandatory = true, description = "The number of trees to build.")
    protected int numTrees;

    @Config(description = "The learning rate, shrinks the new tree output to prevent overfitting.")
    private double eta = 0.3;

    @Config(description = "Minimum loss reduction needed to split a tree node.")
    private double gamma = 0.0;

    @Config(description = "The maximum depth of any tree.")
    private int maxDepth = 6;

    @Config(description = "The minimum weight in each child node before a split is valid.")
    private double minChildWeight = 1.0;

    @Config(description = "Independently subsample the examples for each tree.")
    private double subsample = 1.0;

    @Config(description = "Independently subsample the features available for each node of each tree.")
    private double featureSubsample = 1.0;

    @Config(description = "L2 regularisation term on the weights.")
    private double lambda = 1.0;

    @Config(description = "L1 regularisation term on the weights.")
    private double alpha = 0.0;

    @Config(description = "The number of threads to use at training time.")
    private int nThread = 4;

    @Config(description = "Quiesce all the logging output from the XGBoost native library.")
    private int silent = 1;

    @Config(description = "Type of the weak learner.")
    private BoosterType booster = BoosterType.GBTREE;

    @Config(description = "The RNG seed.")
    private long seed = Trainer.DEFAULT_SEED;

    protected int trainInvocationCounter = 0;

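    /**
     * Create an XGBoost trainer with the given number of trees, using the default
     * hyperparameters, 4 computation threads and silenced training output.
     * @param numTrees Number of trees to boost.
     */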
    protected XGBoostTrainer(int numTrees) {
        this(numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, 4, true, Trainer.DEFAULT_SEED);
    }

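    /**
     * Create an XGBoost trainer with the given number of trees, thread count and
     * logging verbosity, using the default hyperparameters.
     * @param numTrees Number of trees to boost.
     * @param numThreads Number of threads to use at training time.
     * @param silent Silence the training output text.
     */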
    protected XGBoostTrainer(int numTrees, int numThreads, boolean silent) {
        this(numTrees, 0.3, 0, 6, 1, 1, 1, 1, 0, numThreads, silent, Trainer.DEFAULT_SEED);
    }

    /**
     * Create an XGBoost trainer.
     *
     * @param numTrees Number of trees to boost.
     * @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
     * @param gamma Minimum loss reduction to make a split (default 0, range
     * [0,inf]).
     * @param maxDepth Maximum tree depth (default 6, range [1,inf]).
     * @param minChildWeight Minimum sum of instance weights needed in a leaf
     * (default 1, range [0,inf]).
     * @param subsample Subsample size for each tree (default 1, range (0,1]).
     * @param featureSubsample Subsample features for each tree (default 1,
     * range (0,1]).
     * @param lambda L2 regularization term on weights (default 1).
     * @param alpha L1 regularization term on weights (default 0).
     * @param nThread Number of threads to use (default 4).
     * @param silent Silence the training output text.
     * @param seed RNG seed.
     */
    protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed) {
        if (numTrees < 1) {
            throw new IllegalArgumentException("Must supply a positive number of trees. Received " + numTrees);
        }
        this.numTrees = numTrees;
        this.eta = eta;
        this.gamma = gamma;
        this.maxDepth = maxDepth;
        this.minChildWeight = minChildWeight;
        this.subsample = subsample;
        this.featureSubsample = featureSubsample;
        this.lambda = lambda;
        this.alpha = alpha;
        this.nThread = nThread;
        this.silent = silent ? 1 : 0;
        this.seed = seed;
    }

    /**
     * This gives direct access to the XGBoost parameter map.
     * <p>
     * It lets you set parameters which Tribuo doesn't expose directly, like dropout trees, binary classification objectives, etc.
     * <p>
     * This sidesteps the validation that Tribuo provides for the hyperparameters, and so can produce unexpected results.
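     * <p>
     * For example, a hypothetical parameter map selecting DART boosting (the keys
     * are standard XGBoost parameter names, see the XGBoost documentation for the full list):
     * <pre>{@code
     *    Map<String,Object> params = new HashMap<>();
     *    params.put("booster", "dart");               // gradient boosted trees with dropout
     *    params.put("rate_drop", 0.1);                // DART dropout rate
     *    params.put("objective", "binary:logistic");  // binary classification objective
     * }</pre>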
     * @param numTrees Number of trees to boost.
     * @param parameters A map from string to object, where each value must be a Number or a String.
     */
    protected XGBoostTrainer(int numTrees, Map<String,Object> parameters) {
        if (numTrees < 1) {
            throw new IllegalArgumentException("Must supply a positive number of trees. Received " + numTrees);
        }
        this.numTrees = numTrees;
        this.parameters.putAll(parameters);
    }

    /**
     * For OLCUT.
     */
    protected XGBoostTrainer() { }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        parameters.put("eta", eta);
        parameters.put("gamma", gamma);
        parameters.put("max_depth", maxDepth);
        parameters.put("min_child_weight", minChildWeight);
        parameters.put("subsample", subsample);
        parameters.put("colsample_bytree", featureSubsample);
        parameters.put("lambda", lambda);
        parameters.put("alpha", alpha);
        parameters.put("nthread", nThread);
        parameters.put("seed", seed);
        parameters.put("silent", silent);
        parameters.put("booster", booster.paramName);
    }

    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();

        buffer.append("XGBoostTrainer(numTrees=");
        buffer.append(numTrees);
        buffer.append(",parameters=");
        buffer.append(parameters.toString());
        buffer.append(")");

        return buffer.toString();
    }

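    /**
     * Creates an XGBoost model from the supplied boosters.
     * @param name The model name.
     * @param provenance The model provenance.
     * @param featureIDMap The feature id map.
     * @param outputIDInfo The output id info.
     * @param models The list of trained XGBoost boosters.
     * @param converter The converter from XGBoost's output format into Tribuo predictions.
     * @return An XGBoostModel.
     */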
    protected XGBoostModel<T> createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Booster> models, XGBoostOutputConverter<T> converter) {
        return new XGBoostModel<>(name,provenance,featureIDMap,outputIDInfo,models,converter);
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

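    /**
     * Converts a dataset into a DMatrix, using the supplied function to extract the outputs.
     * @param examples The dataset to convert.
     * @param responseExtractor The extraction function for the output.
     * @param <T> The type of the output.
     * @return A DMatrixTuple.
     * @throws XGBoostError If the native library failed to construct the DMatrix.
     */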
    protected static <T extends Output<T>> DMatrixTuple<T> convertDataset(Dataset<T> examples, Function<T,Float> responseExtractor) throws XGBoostError {
        return convertExamples(examples.getData(), examples.getFeatureIDMap(), responseExtractor);
    }

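    /**
     * Converts a dataset into an unlabelled DMatrix.
     * @param examples The dataset to convert.
     * @param <T> The type of the output.
     * @return A DMatrixTuple.
     * @throws XGBoostError If the native library failed to construct the DMatrix.
     */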
    protected static <T extends Output<T>> DMatrixTuple<T> convertDataset(Dataset<T> examples) throws XGBoostError {
        return convertExamples(examples.getData(), examples.getFeatureIDMap(), null);
    }

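    /**
     * Converts an iterable of examples into an unlabelled DMatrix.
     * @param examples The examples to convert.
     * @param featureMap The feature id map which supplies the indices.
     * @param <T> The type of the output.
     * @return A DMatrixTuple.
     * @throws XGBoostError If the native library failed to construct the DMatrix.
     */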
    protected static <T extends Output<T>> DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap) throws XGBoostError {
        return convertExamples(examples, featureMap, null);
    }

    /**
     * Converts an iterable of examples into a DMatrix.
     * @param examples The examples to convert.
     * @param featureMap The feature id map which supplies the indices.
     * @param responseExtractor The extraction function for the output.
     * @param <T> The type of the output.
     * @return A DMatrixTuple.
     * @throws XGBoostError If the native library failed to construct the DMatrix.
     */
    protected static <T extends Output<T>> DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) throws XGBoostError {
        // headers = array of start points for a row
        // indices = array of feature indices for all data
        // data = array of feature values for all data
        // SparseType = DMatrix.SparseType.CSR
        //public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError
        //
        // then call
        //public void setLabel(float[] labels) throws XGBoostError
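        //
        // For example (hypothetical values): two examples, the first with sparse
        // features {1:0.5, 3:0.2} and the second with the single feature {2:1.0},
        // produce the CSR arrays:
        //   headers = [0, 2, 3]
        //   indices = [1, 3, 2]
        //   data    = [0.5f, 0.2f, 1.0f]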

        boolean labelled = responseExtractor != null;
        ArrayList<Float> labelsList = new ArrayList<>();
        ArrayList<Float> dataList = new ArrayList<>();
        ArrayList<Long> headersList = new ArrayList<>();
        ArrayList<Integer> indicesList = new ArrayList<>();
        ArrayList<Float> weightsList = new ArrayList<>();
        ArrayList<Integer> numValidFeatures = new ArrayList<>();
        ArrayList<Example<T>> examplesList = new ArrayList<>();

        long rowHeader = 0;
        headersList.add(rowHeader);
        for (Example<T> e : examples) {
            if (labelled) {
                labelsList.add(responseExtractor.apply(e.getOutput()));
                weightsList.add(e.getWeight());
            }
            examplesList.add(e);
            long newRowHeader = convertSingleExample(e,featureMap,dataList,indicesList,headersList,rowHeader);
            numValidFeatures.add((int) (newRowHeader-rowHeader));
            rowHeader = newRowHeader;
        }

        float[] data = Util.toPrimitiveFloat(dataList);
        int[] indices = Util.toPrimitiveInt(indicesList);
        long[] headers = Util.toPrimitiveLong(headersList);

        DMatrix dataMatrix = new DMatrix(headers, indices, data, DMatrix.SparseType.CSR,featureMap.size());
        if (labelled) {
            float[] labels = Util.toPrimitiveFloat(labelsList);
            dataMatrix.setLabel(labels);
            float[] weights = Util.toPrimitiveFloat(weightsList);
            dataMatrix.setWeight(weights);
        }
        @SuppressWarnings("unchecked") // Generic array creation
        Example<T>[] exampleArray = (Example<T>[])examplesList.toArray(new Example[0]);
        return new DMatrixTuple<>(dataMatrix,Util.toPrimitiveInt(numValidFeatures),exampleArray);
    }

    protected static <T extends Output<T>> DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap) throws XGBoostError {
        return convertExample(example,featureMap,null);
    }

    /**
     * Converts a single example into a DMatrix.
     * @param example The example to convert.
     * @param featureMap The feature id map which supplies the indices.
     * @param responseExtractor The extraction function for the output.
     * @param <T> The type of the output.
     * @return A DMatrixTuple.
     * @throws XGBoostError If the native library failed to construct the DMatrix.
     */
    protected static <T extends Output<T>> DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T,Float> responseExtractor) throws XGBoostError {
        // headers = array of start points for a row
        // indices = array of feature indices for all data
        // data = array of feature values for all data
        // SparseType = DMatrix.SparseType.CSR
        //public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError
        //
        // then call
        //public void setLabel(float[] labels) throws XGBoostError

        boolean labelled = responseExtractor != null;
        ArrayList<Float> dataList = new ArrayList<>();
        ArrayList<Integer> indicesList = new ArrayList<>();
        ArrayList<Long> headersList = new ArrayList<>();
        headersList.add(0L);

        long header = convertSingleExample(example,featureMap,dataList,indicesList,headersList,0);

        float[] data = Util.toPrimitiveFloat(dataList);
        int[] indices = Util.toPrimitiveInt(indicesList);
        long[] headers = Util.toPrimitiveLong(headersList);

        DMatrix dataMatrix = new DMatrix(headers, indices, data, DMatrix.SparseType.CSR,featureMap.size());
        if (labelled) {
            float[] labels = new float[1];
            labels[0] = responseExtractor.apply(example.getOutput());
            dataMatrix.setLabel(labels);
            float[] weights = new float[1];
            weights[0] = example.getWeight();
            dataMatrix.setWeight(weights);
        }
        @SuppressWarnings("unchecked") // Generic array creation
        Example<T>[] exampleArray = (Example<T>[])new Example[]{example};
        return new DMatrixTuple<>(dataMatrix,new int[]{(int)header},exampleArray);
    }

    /**
     * Writes out the features from an example into the three supplied {@link ArrayList}s.
     * <p>
     * This is used to transform examples into the right format for an XGBoost call.
     * It's used in both the Classification and Regression XGBoost backends.
     * The ArrayLists must be non-null, and can contain existing values (as this
     * method is called multiple times to build up an ArrayList containing all the
     * feature values for a dataset).
     * <p>
     * Features with colliding feature ids are summed together. For example, if two
     * features in an example both map to id 5, with values 1.0 and 2.5, then a single
     * entry with value 3.5 is written out for index 5.
     * <p>
     * Throws {@link IllegalArgumentException} if the {@link Example} contains no features.
     * @param example The example to inspect.
     * @param featureMap The feature map of the model/dataset (used to preserve hash information).
     * @param dataList The output feature values.
     * @param indicesList The output indices.
     * @param headersList The output header position (an integer saying how long each sparse example is).
     * @param header The current header position.
     * @param <T> The type of the example.
     * @return The updated header position.
     */
    protected static <T extends Output<T>> long convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header) {
        int numActiveFeatures = 0;
        int prevIdx = -1;
        int indicesSize = indicesList.size();
        for (Feature f : example) {
            int id = featureMap.getID(f.getName());
            if (id > prevIdx){
                prevIdx = id;
                dataList.add((float) f.getValue());
                indicesList.add(id);
                numActiveFeatures++;
            } else if (id > -1) {
                //
                // Collision, deal with it.
                int collisionIdx = Util.binarySearch(indicesList,id,indicesSize,numActiveFeatures+indicesSize);
                if (collisionIdx < 0) {
                    //
                    // Collision, but this id is not yet present in indicesList,
                    // so insert it at the right position to keep the indices sorted.
                    collisionIdx = - (collisionIdx + 1);
                    indicesList.add(collisionIdx,id);
                    dataList.add(collisionIdx,(float) f.getValue());
                    numActiveFeatures++;
                } else {
                    //
                    // This id is already present in indicesList,
                    // so sum the feature values.
                    dataList.set(collisionIdx, dataList.get(collisionIdx) + (float) f.getValue());
                }
            }
        }
        if (numActiveFeatures == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        header += numActiveFeatures;
        headersList.add(header);
        return header;
    }

    /**
     * Writes out the features from a SparseVector into the three supplied {@link ArrayList}s.
     * <p>
     * This is used to transform examples into the right format for an XGBoost call.
     * It's used when predicting with an externally trained XGBoost model, as the
     * external training may not respect Tribuo's feature ordering constraints.
     * The ArrayLists must be non-null, and can contain existing values (as this
     * method is called multiple times to build up an ArrayList containing all the
     * feature values for a dataset).
     * <p>
     * This is much simpler than {@link XGBoostTrainer#convertSingleExample} as the validation
     * of feature indices is done in the {@link org.tribuo.interop.ExternalModel} class.
     * @param vector The features to convert.
     * @param dataList The output feature values.
     * @param indicesList The output indices.
     * @param headersList The output header position (an integer saying how long each sparse example is).
     * @param header The current header position.
     * @return The updated header position.
     */
    static long convertSingleExample(SparseVector vector, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header) {
        int numActiveFeatures = 0;
        for (VectorTuple v : vector) {
            dataList.add((float) v.value);
            indicesList.add(v.index);
            numActiveFeatures++;
        }
        header += numActiveFeatures;
        headersList.add(header);
        return header;
    }

    /**
     * Used when predicting with an externally trained XGBoost model.
     * @param vector The features to convert.
     * @return A DMatrix representing the features.
     * @throws XGBoostError If the native library returns an error state.
     */
    protected static DMatrix convertSparseVector(SparseVector vector) throws XGBoostError {
        // headers = array of start points for a row
        // indices = array of feature indices for all data
        // data = array of feature values for all data
        // SparseType = DMatrix.SparseType.CSR
        //public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError
        ArrayList<Float> dataList = new ArrayList<>();
        ArrayList<Long> headersList = new ArrayList<>();
        ArrayList<Integer> indicesList = new ArrayList<>();

        long rowHeader = 0;
        headersList.add(rowHeader);
        convertSingleExample(vector,dataList,indicesList,headersList,rowHeader);

        float[] data = Util.toPrimitiveFloat(dataList);
        int[] indices = Util.toPrimitiveInt(indicesList);
        long[] headers = Util.toPrimitiveLong(headersList);

        return new DMatrix(headers, indices, data, DMatrix.SparseType.CSR,vector.size());
    }

    /**
     * Used when predicting with an externally trained XGBoost model.
     * <p>
     * It is assumed all vectors are the same size when passed into this function.
     * @param vectors The batch of features to convert.
     * @return A DMatrix representing the batch of features.
     * @throws XGBoostError If the native library returns an error state.
     */
    protected static DMatrix convertSparseVectors(List<SparseVector> vectors) throws XGBoostError {
        // headers = array of start points for a row
        // indices = array of feature indices for all data
        // data = array of feature values for all data
        // SparseType = DMatrix.SparseType.CSR
        //public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError
        ArrayList<Float> dataList = new ArrayList<>();
        ArrayList<Long> headersList = new ArrayList<>();
        ArrayList<Integer> indicesList = new ArrayList<>();

        int numFeatures = 0;
        long rowHeader = 0;
        headersList.add(rowHeader);
        for (SparseVector e : vectors) {
            rowHeader = convertSingleExample(e,dataList,indicesList,headersList,rowHeader);
            numFeatures = e.size(); // All vectors are assumed to be the same size.
        }

        float[] data = Util.toPrimitiveFloat(dataList);
        int[] indices = Util.toPrimitiveInt(indicesList);
        long[] headers = Util.toPrimitiveLong(headersList);

        return new DMatrix(headers, indices, data, DMatrix.SparseType.CSR, numFeatures);
    }

    /**
     * Tuple of a DMatrix, the number of valid features in each example, and the examples themselves.
     * <p>
     * One day it'll be a record.
     * @param <T> The output type.
     */
    protected static class DMatrixTuple<T extends Output<T>> {
        /** The data matrix. */
        public final DMatrix data;
        /** The number of valid features in each example. */
        public final int[] numValidFeatures;
        /** The examples. */
        public final Example<T>[] examples;

        /**
         * Constructs a DMatrixTuple.
         * @param data The data matrix.
         * @param numValidFeatures The number of valid features in each example.
         * @param examples The examples.
         */
        public DMatrixTuple(DMatrix data, int[] numValidFeatures, Example<T>[] examples) {
            this.data = data;
            this.numValidFeatures = numValidFeatures;
            this.examples = examples;
        }
    }

    /**
     * Provenance for {@link XGBoostTrainer}. No longer used.
     */
    @Deprecated
    protected static class XGBoostTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;

        protected <T extends Output<T>> XGBoostTrainerProvenance(XGBoostTrainer<T> host) {
            super(host);
        }

        protected XGBoostTrainerProvenance(Map<String,Provenance> map) {
            super(map);
        }
    }
}