001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.xgboost;
018
019import com.oracle.labs.mlrg.olcut.util.MutableDouble;
020import com.oracle.labs.mlrg.olcut.util.Pair;
021import ml.dmlc.xgboost4j.java.Booster;
022import ml.dmlc.xgboost4j.java.XGBoost;
023import ml.dmlc.xgboost4j.java.XGBoostError;
024import org.tribuo.Dataset;
025import org.tribuo.Example;
026import org.tribuo.Excuse;
027import org.tribuo.ImmutableFeatureMap;
028import org.tribuo.ImmutableOutputInfo;
029import org.tribuo.Model;
030import org.tribuo.Output;
031import org.tribuo.Prediction;
032import org.tribuo.common.xgboost.XGBoostTrainer.DMatrixTuple;
033import org.tribuo.provenance.ModelProvenance;
034
035import java.io.ByteArrayInputStream;
036import java.io.IOException;
037import java.io.ObjectInputStream;
038import java.io.ObjectOutputStream;
039import java.util.ArrayList;
040import java.util.Collections;
041import java.util.Comparator;
042import java.util.HashMap;
043import java.util.List;
044import java.util.Map;
045import java.util.Optional;
046import java.util.PriorityQueue;
047import java.util.logging.Level;
048import java.util.logging.Logger;
049
050/**
051 * A {@link Model} which wraps around a XGBoost.Booster.
052 * <p>
053 * XGBoost is a fast implementation of gradient boosted decision trees.
054 * <p>
055 * Throws IllegalStateException if the XGBoost C++ library fails to load or throws an exception.
056 * <p>
057 * See:
058 * <pre>
059 * Chen T, Guestrin C.
060 * "XGBoost: A Scalable Tree Boosting System"
061 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
062 * </pre>
063 * and for the original algorithm:
064 * <pre>
065 * Friedman JH.
066 * "Greedy Function Approximation: a Gradient Boosting Machine"
067 * Annals of statistics, 2001.
068 * </pre>
069 * <p>
070 * Note: XGBoost requires a native library, on macOS this library requires libomp (which can be installed via homebrew),
071 * on Windows this native library must be compiled into a jar as it's not contained in the official XGBoost binary
072 * on Maven Central.
073 */
074public final class XGBoostModel<T extends Output<T>> extends Model<T> {
075    private static final long serialVersionUID = 4L;
076
077    private static final Logger logger = Logger.getLogger(XGBoostModel.class.getName());
078
079    private final XGBoostOutputConverter<T> converter;
080
081    /**
082     * The XGBoost4J Boosters.
083     */
084    protected transient List<Booster> models;
085
086    XGBoostModel(String name, ModelProvenance description,
087                 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> labelIDMap,
088                 List<Booster> models, XGBoostOutputConverter<T> converter) {
089        super(name,description,featureIDMap,labelIDMap,converter.generatesProbabilities());
090        this.converter = converter;
091        this.models = models;
092    }
093
094    /**
095     * Returns an unmodifiable list containing a copy of each model.
096     * <p>
097     * As XGBoost4J models don't expose a copy constructor this requires
098     * serializing each model to a byte array and rebuilding it, and is thus quite expensive.
099     * @return A copy of all of the models.
100     */
101    public List<Booster> getInnerModels() {
102        List<Booster> copy = new ArrayList<>();
103
104        for (Booster m : models) {
105            copy.add(copyModel(m));
106        }
107
108        return Collections.unmodifiableList(copy);
109    }
110
111    /**
112     * Sets the number of threads to use at prediction time.
113     * <p>
114     * If set to 0 sets nthreads = num hardware threads.
115     * @param threads The new number of threads.
116     */
117    public void setNumThreads(int threads) {
118        if (threads > -1) {
119            try {
120                for (Booster model : models) {
121                    model.setParam("nthread", threads);
122                }
123            } catch (XGBoostError e) {
124                logger.log(Level.SEVERE, "XGBoost threw an error", e);
125                throw new IllegalStateException(e);
126            }
127        }
128    }
129
130    /**
131     * Uses the model to predict the labels for multiple examples contained in
132     * a data set.
133     * @param examples the data set containing the examples to predict.
134     * @return the results of the predictions, in the same order as the
135     * data set generates the example.
136     */
137    @Override
138    public List<Prediction<T>> predict(Dataset<T> examples) {
139        return predict(examples.getData());
140    }
141
142    /**
143     * Uses the model to predict the label for multiple examples.
144     * @param examples the examples to predict.
145     * @return the results of the prediction, in the same order as the
146     * examples.
147     */
148    @Override
149    public List<Prediction<T>> predict(Iterable<Example<T>> examples) {
150        try {
151            DMatrixTuple<T> testMatrix = XGBoostTrainer.convertExamples(examples,featureIDMap);
152            List<float[][]> outputs = new ArrayList<>();
153            for (Booster model : models) {
154                outputs.add(model.predict(testMatrix.data));
155            }
156
157            int[] numValidFeatures = testMatrix.numValidFeatures;
158            Example<T>[] exampleArray = testMatrix.examples;
159            return converter.convertBatchOutput(outputIDInfo,outputs,numValidFeatures,exampleArray);
160        } catch (XGBoostError e) {
161            logger.log(Level.SEVERE, "XGBoost threw an error", e);
162            throw new IllegalStateException(e);
163        }
164
165    }
166
167    @Override
168    public Prediction<T> predict(Example<T> example) {
169        try {
170            DMatrixTuple<T> testData = XGBoostTrainer.convertExample(example,featureIDMap);
171            List<float[]> outputs = new ArrayList<>();
172            for (Booster model : models) {
173                outputs.add(model.predict(testData.data)[0]);
174            }
175            Prediction<T> pred = converter.convertOutput(outputIDInfo,outputs,testData.numValidFeatures[0],example);
176            return pred;
177        } catch (XGBoostError e) {
178            logger.log(Level.SEVERE, "XGBoost threw an error", e);
179            throw new IllegalStateException(e);
180        }
181    }
182
183    @Override
184    public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) {
185        try {
186            int maxFeatures = n < 0 ? featureIDMap.size() : n;
187            // Aggregate feature scores across all the models.
188            // This throws away model specific information which is useful in the case of regression,
189            // but it's very tricky to get the dimension name associated with the model.
190            Map<String, MutableDouble> outputMap = new HashMap<>();
191            for (Booster model : models) {
192                Map<String, Integer> xgboostMap = model.getFeatureScore("");
193                for (Map.Entry<String,Integer> f : xgboostMap.entrySet()) {
194                    int id = Integer.parseInt(f.getKey().substring(1));
195                    String name = featureIDMap.get(id).getName();
196                    MutableDouble curVal = outputMap.computeIfAbsent(name,(k)->new MutableDouble());
197                    curVal.increment(f.getValue());
198                }
199            }
200            Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
201            PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures,comparator);
202            for (Map.Entry<String,MutableDouble> e : outputMap.entrySet()) {
203                Pair<String,Double> cur = new Pair<>(e.getKey(), e.getValue().doubleValue());
204
205                if (q.size() < maxFeatures) {
206                    q.offer(cur);
207                } else if (comparator.compare(cur,q.peek()) > 0) {
208                    q.poll();
209                    q.offer(cur);
210                }
211            }
212            List<Pair<String,Double>> list = new ArrayList<>();
213            while(q.size() > 0) {
214                list.add(q.poll());
215            }
216            Collections.reverse(list);
217
218            Map<String, List<Pair<String,Double>>> map = new HashMap<>();
219            map.put(Model.ALL_OUTPUTS,list);
220
221            return map;
222        } catch (XGBoostError e) {
223            logger.log(Level.SEVERE, "XGBoost threw an error", e);
224            return Collections.emptyMap();
225        }
226    }
227
228    /**
229     * Returns the string model dumps from each Booster.
230     * @return The model dumps.
231     */
232    public List<String[]> getModelDump() {
233        try {
234            List<String[]> list = new ArrayList<>();
235            for (Booster m : models) {
236                list.add(m.getModelDump("", true));
237            }
238            return list;
239        } catch (XGBoostError e) {
240            throw new IllegalStateException(e);
241        }
242    }
243
244    @Override
245    public Optional<Excuse<T>> getExcuse(Example<T> example) {
246        return Optional.empty();
247    }
248
249    /**
250     * Copies a single XGBoost Booster by serializing and deserializing it.
251     * @param booster The booster to copy.
252     * @return A deep copy of the booster.
253     */
254    static Booster copyModel(Booster booster) {
255        try {
256            byte[] serialisedBooster = booster.toByteArray();
257            return XGBoost.loadModel(new ByteArrayInputStream(serialisedBooster));
258        } catch (XGBoostError | IOException e) {
259            throw new IllegalStateException("Unable to copy XGBoost model.",e);
260        }
261    }
262
263    @Override
264    protected Model<T> copy(String newName, ModelProvenance newProvenance) {
265        List<Booster> newModels = new ArrayList<>();
266        for (Booster model : models) {
267            newModels.add(copyModel(model));
268        }
269        return new XGBoostModel<>(newName, newProvenance, featureIDMap, outputIDInfo, newModels, converter);
270    }
271
272    private void writeObject(ObjectOutputStream out) throws IOException {
273        out.defaultWriteObject();
274        try {
275            out.writeInt(models.size());
276            for (Booster model : models) {
277                byte[] serialisedBooster = model.toByteArray();
278                out.writeObject(serialisedBooster);
279            }
280        } catch (XGBoostError e) {
281            throw new IOException("Failed to serialize the XGBoost model",e);
282        }
283    }
284
285    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
286        in.defaultReadObject();
287        try {
288            models = new ArrayList<>();
289            int numModels = in.readInt();
290            for (int i = 0; i < numModels; i++) {
291                // Now read in the byte array and rebuild each Booster
292                byte[] serialisedBooster = (byte[]) in.readObject();
293                models.add(XGBoost.loadModel(new ByteArrayInputStream(serialisedBooster)));
294            }
295        } catch (XGBoostError e) {
296            throw new IOException("Failed to deserialize the XGBoost model",e);
297        }
298    }
299}