001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.xgboost;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import ml.dmlc.xgboost4j.java.Booster;
021import ml.dmlc.xgboost4j.java.DMatrix;
022import ml.dmlc.xgboost4j.java.XGBoost;
023import ml.dmlc.xgboost4j.java.XGBoostError;
024import org.tribuo.Example;
025import org.tribuo.ImmutableFeatureMap;
026import org.tribuo.ImmutableOutputInfo;
027import org.tribuo.Model;
028import org.tribuo.Output;
029import org.tribuo.OutputFactory;
030import org.tribuo.Prediction;
031import org.tribuo.interop.ExternalDatasetProvenance;
032import org.tribuo.interop.ExternalModel;
033import org.tribuo.interop.ExternalTrainerProvenance;
034import org.tribuo.math.la.SparseVector;
035import org.tribuo.provenance.DatasetProvenance;
036import org.tribuo.provenance.ModelProvenance;
037
038import java.io.ByteArrayInputStream;
039import java.io.File;
040import java.io.IOException;
041import java.io.ObjectInputStream;
042import java.io.ObjectOutputStream;
043import java.net.MalformedURLException;
044import java.net.URL;
045import java.nio.file.Files;
046import java.nio.file.Path;
047import java.time.OffsetDateTime;
048import java.util.ArrayList;
049import java.util.Collections;
050import java.util.Comparator;
051import java.util.HashMap;
052import java.util.List;
053import java.util.Map;
054import java.util.PriorityQueue;
055import java.util.logging.Level;
056import java.util.logging.Logger;
057
058/**
059 * A {@link Model} which wraps around a XGBoost.Booster which was trained by a system other than Tribuo.
060 * <p>
061 * XGBoost is a fast implementation of gradient boosted decision trees.
062 * <p>
063 * Throws IllegalStateException if the XGBoost C++ library fails to load or throws an exception.
064 * <p>
065 * See:
066 * <pre>
067 * Chen T, Guestrin C.
068 * "XGBoost: A Scalable Tree Boosting System"
069 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016.
070 * </pre>
071 * <p>
072 * and for the original algorithm:
073 * <pre>
074 * Friedman JH.
075 * "Greedy Function Approximation: a Gradient Boosting Machine"
076 * Annals of statistics, 2001.
077 * </pre>
078 * <p>
079 * Note: XGBoost requires a native library, on macOS this library requires libomp (which can be installed via homebrew),
080 * on Windows this native library must be compiled into a jar as it's not contained in the official XGBoost binary
081 * on Maven Central.
082 */
083public final class XGBoostExternalModel<T extends Output<T>> extends ExternalModel<T,DMatrix,float[][]> {
084    private static final long serialVersionUID = 1L;
085
086    private static final Logger logger = Logger.getLogger(XGBoostExternalModel.class.getName());
087
088    private final XGBoostOutputConverter<T> converter;
089
090    /**
091     * Transient as we rely upon the native serialisation mechanism to bytes rather than Java serializing the Booster.
092     */
093    protected transient Booster model;
094
095    private XGBoostExternalModel(String name, ModelProvenance provenance,
096                                 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
097                                 Map<String, Integer> featureMapping, Booster model,
098                                 XGBoostOutputConverter<T> converter) {
099        super(name, provenance, featureIDMap, outputIDInfo, converter.generatesProbabilities(), featureMapping);
100        this.model = model;
101        this.converter = converter;
102    }
103
104    private XGBoostExternalModel(String name, ModelProvenance provenance,
105                                 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
106                                 int[] featureForwardMapping, int[] featureBackwardMapping,
107                                 Booster model, XGBoostOutputConverter<T> converter) {
108        super(name,provenance,featureIDMap,outputIDInfo,featureForwardMapping,featureBackwardMapping,
109              converter.generatesProbabilities());
110        this.model = model;
111        this.converter = converter;
112    }
113
114    @Override
115    protected DMatrix convertFeatures(SparseVector input) {
116        try {
117            return XGBoostTrainer.convertSparseVector(input);
118        } catch (XGBoostError e) {
119            logger.severe("XGBoost threw an error while constructing the DMatrix.");
120            throw new IllegalStateException(e);
121        }
122    }
123
124    @Override
125    protected DMatrix convertFeaturesList(List<SparseVector> input) {
126        try {
127            return XGBoostTrainer.convertSparseVectors(input);
128        } catch (XGBoostError e) {
129            logger.severe("XGBoost threw an error while constructing the DMatrix.");
130            throw new IllegalStateException(e);
131        }
132    }
133
134    @Override
135    protected float[][] externalPrediction(DMatrix input) {
136        try {
137            return model.predict(input);
138        } catch (XGBoostError e) {
139            logger.severe("XGBoost threw an error while predicting.");
140            throw new IllegalStateException(e);
141        }
142    }
143
144    @Override
145    protected Prediction<T> convertOutput(float[][] output, int numValidFeatures, Example<T> example) {
146        return converter.convertOutput(outputIDInfo,Collections.singletonList(output[0]),numValidFeatures,example);
147    }
148
149    @SuppressWarnings("unchecked") // generic array creation
150    @Override
151    protected List<Prediction<T>> convertOutput(float[][] output, int[] numValidFeatures, List<Example<T>> examples) {
152        return converter.convertBatchOutput(outputIDInfo,Collections.singletonList(output),numValidFeatures,(Example<T>[])examples.toArray(new Example[0]));
153    }
154
155    @Override
156    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
157        try {
158            int maxFeatures = n < 0 ? featureIDMap.size() : n;
159            Map<String, Integer> xgboostMap = model.getFeatureScore("");
160            Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
161            PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures,comparator);
162            //iterate over the scored features
163            for (Map.Entry<String, Integer> f : xgboostMap.entrySet()) {
164                int id = Integer.parseInt(f.getKey().substring(1));
165                Pair<String,Double> cur = new Pair<>(featureIDMap.get(featureBackwardMapping[id]).getName(), (double) f.getValue());
166
167                if (q.size() < maxFeatures) {
168                    q.offer(cur);
169                } else if (comparator.compare(cur,q.peek()) > 0) {
170                    q.poll();
171                    q.offer(cur);
172                }
173            }
174            List<Pair<String,Double>> list = new ArrayList<>();
175            while(q.size() > 0) {
176                list.add(q.poll());
177            }
178            Collections.reverse(list);
179
180            Map<String, List<Pair<String,Double>>> map = new HashMap<>();
181            map.put(Model.ALL_OUTPUTS,list);
182
183            return map;
184        } catch (XGBoostError e) {
185            logger.log(Level.SEVERE, "XGBoost threw an error", e);
186            return Collections.emptyMap();
187        }
188    }
189
190    @Override
191    protected XGBoostExternalModel<T> copy(String newName, ModelProvenance newProvenance) {
192        return new XGBoostExternalModel<>(newName, newProvenance, featureIDMap, outputIDInfo,
193                                          featureForwardMapping, featureBackwardMapping,
194                                          XGBoostModel.copyModel(model), converter);
195    }
196
197    /**
198     * Creates an {@code XGBoostExternalModel} from the supplied model on disk.
199     * @param factory The output factory to use.
200     * @param featureMapping The feature mapping between Tribuo names and XGBoost integer ids.
201     * @param outputMapping The output mapping between Tribuo outputs and XGBoost integer ids.
202     * @param outputFunc The XGBoostOutputConverter function for the output type.
203     * @param path The path to the model on disk.
204     * @param <T> The type of the output.
205     * @return An XGBoostExternalModel ready to score new inputs.
206     */
207    public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T,Integer> outputMapping, XGBoostOutputConverter<T> outputFunc, String path) {
208        try {
209            Booster model = XGBoost.loadModel(path);
210            return createXGBoostModel(factory,featureMapping,outputMapping,outputFunc,model,new File(path).toURI().toURL());
211        } catch (XGBoostError | MalformedURLException e) {
212            throw new IllegalArgumentException("Unable to load model from path " + path, e);
213        }
214    }
215
216    /**
217     * Creates an {@code XGBoostExternalModel} from the supplied model on disk.
218     * @param factory The output factory to use.
219     * @param featureMapping The feature mapping between Tribuo names and XGBoost integer ids.
220     * @param outputMapping The output mapping between Tribuo outputs and XGBoost integer ids.
221     * @param outputFunc The XGBoostOutputConverter function for the output type.
222     * @param path The path to the model on disk.
223     * @param <T> The type of the output.
224     * @return An XGBoostExternalModel ready to score new inputs.
225     */
226    public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T,Integer> outputMapping, XGBoostOutputConverter<T> outputFunc, Path path) {
227        try {
228            Booster model = XGBoost.loadModel(Files.newInputStream(path));
229            return createXGBoostModel(factory,featureMapping,outputMapping,outputFunc,model,path.toUri().toURL());
230        } catch (XGBoostError | IOException e) {
231            throw new IllegalArgumentException("Unable to load model from path " + path, e);
232        }
233    }
234
235    /**
236     * Creates an {@code XGBoostExternalModel} from the supplied model.
237     * <p>
238     * Note: the provenance system requires that the URL point to a valid local file and
239     * will throw an exception if it is not. However it doesn't check that the file is
240     * where the Booster was created from.
241     * We will replace this entry point with one that accepts useful provenance information
242     * for an in-memory {@code Booster} object in a future release, and deprecate this
243     * endpoint at that time.
244     * @param factory The output factory to use.
245     * @param featureMapping The feature mapping between Tribuo names and XGBoost integer ids.
246     * @param outputMapping The output mapping between Tribuo outputs and XGBoost integer ids.
247     * @param outputFunc The XGBoostOutputConverter function for the output type.
248     * @param model The XGBoost model to wrap.
249     * @param provenanceLocation The location where the model was loaded from.
250     * @param <T> The type of the output.
251     * @return An XGBoostExternalModel ready to score new inputs.
252     */
253    public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, XGBoostOutputConverter<T> outputFunc, Booster model, URL provenanceLocation) {
254        //TODO: add a new version of this method which accepts useful instance provenance information and deprecate this one
255        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
256        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory,outputMapping);
257        OffsetDateTime now = OffsetDateTime.now();
258        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
259        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data",factory,false,featureMapping.size(),outputMapping.size());
260        ModelProvenance provenance = new ModelProvenance(XGBoostExternalModel.class.getName(),now,datasetProvenance,trainerProvenance);
261        return new XGBoostExternalModel<>("external-model",provenance,featureMap,outputInfo,
262                                          featureMapping,model,outputFunc);
263    }
264
265    private void writeObject(ObjectOutputStream out) throws IOException {
266        out.defaultWriteObject();
267        try {
268            byte[] serialisedBooster = model.toByteArray();
269            out.writeObject(serialisedBooster);
270        } catch (XGBoostError e) {
271            throw new IOException("Failed to serialize the XGBoost model",e);
272        }
273    }
274
275    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
276        in.defaultReadObject();
277        try {
278            // Now read in the byte array and rebuild the Booster
279            byte[] serialisedBooster = (byte[]) in.readObject();
280            model = XGBoost.loadModel(new ByteArrayInputStream(serialisedBooster));
281        } catch (XGBoostError e) {
282            throw new IOException("Failed to deserialize the XGBoost model",e);
283        }
284    }
285}