001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.xgboost; 018 019import com.oracle.labs.mlrg.olcut.util.MutableDouble; 020import com.oracle.labs.mlrg.olcut.util.Pair; 021import ml.dmlc.xgboost4j.java.Booster; 022import ml.dmlc.xgboost4j.java.XGBoost; 023import ml.dmlc.xgboost4j.java.XGBoostError; 024import org.tribuo.Dataset; 025import org.tribuo.Example; 026import org.tribuo.Excuse; 027import org.tribuo.ImmutableFeatureMap; 028import org.tribuo.ImmutableOutputInfo; 029import org.tribuo.Model; 030import org.tribuo.Output; 031import org.tribuo.Prediction; 032import org.tribuo.common.xgboost.XGBoostTrainer.DMatrixTuple; 033import org.tribuo.provenance.ModelProvenance; 034 035import java.io.ByteArrayInputStream; 036import java.io.IOException; 037import java.io.ObjectInputStream; 038import java.io.ObjectOutputStream; 039import java.util.ArrayList; 040import java.util.Collections; 041import java.util.Comparator; 042import java.util.HashMap; 043import java.util.List; 044import java.util.Map; 045import java.util.Optional; 046import java.util.PriorityQueue; 047import java.util.logging.Level; 048import java.util.logging.Logger; 049 050/** 051 * A {@link Model} which wraps around a XGBoost.Booster. 052 * <p> 053 * XGBoost is a fast implementation of gradient boosted decision trees. 054 * <p> 055 * Throws IllegalStateException if the XGBoost C++ library fails to load or throws an exception. 056 * <p> 057 * See: 058 * <pre> 059 * Chen T, Guestrin C. 060 * "XGBoost: A Scalable Tree Boosting System" 061 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016. 062 * </pre> 063 * and for the original algorithm: 064 * <pre> 065 * Friedman JH. 066 * "Greedy Function Approximation: a Gradient Boosting Machine" 067 * Annals of statistics, 2001. 068 * </pre> 069 * <p> 070 * Note: XGBoost requires a native library, on macOS this library requires libomp (which can be installed via homebrew), 071 * on Windows this native library must be compiled into a jar as it's not contained in the official XGBoost binary 072 * on Maven Central. 073 */ 074public final class XGBoostModel<T extends Output<T>> extends Model<T> { 075 private static final long serialVersionUID = 4L; 076 077 private static final Logger logger = Logger.getLogger(XGBoostModel.class.getName()); 078 079 private final XGBoostOutputConverter<T> converter; 080 081 /** 082 * The XGBoost4J Boosters. 083 */ 084 protected transient List<Booster> models; 085 086 XGBoostModel(String name, ModelProvenance description, 087 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> labelIDMap, 088 List<Booster> models, XGBoostOutputConverter<T> converter) { 089 super(name,description,featureIDMap,labelIDMap,converter.generatesProbabilities()); 090 this.converter = converter; 091 this.models = models; 092 } 093 094 /** 095 * Returns an unmodifiable list containing a copy of each model. 096 * <p> 097 * As XGBoost4J models don't expose a copy constructor this requires 098 * serializing each model to a byte array and rebuilding it, and is thus quite expensive. 099 * @return A copy of all of the models. 100 */ 101 public List<Booster> getInnerModels() { 102 List<Booster> copy = new ArrayList<>(); 103 104 for (Booster m : models) { 105 copy.add(copyModel(m)); 106 } 107 108 return Collections.unmodifiableList(copy); 109 } 110 111 /** 112 * Sets the number of threads to use at prediction time. 113 * <p> 114 * If set to 0 sets nthreads = num hardware threads. 115 * @param threads The new number of threads. 116 */ 117 public void setNumThreads(int threads) { 118 if (threads > -1) { 119 try { 120 for (Booster model : models) { 121 model.setParam("nthread", threads); 122 } 123 } catch (XGBoostError e) { 124 logger.log(Level.SEVERE, "XGBoost threw an error", e); 125 throw new IllegalStateException(e); 126 } 127 } 128 } 129 130 /** 131 * Uses the model to predict the labels for multiple examples contained in 132 * a data set. 133 * @param examples the data set containing the examples to predict. 134 * @return the results of the predictions, in the same order as the 135 * data set generates the example. 136 */ 137 @Override 138 public List<Prediction<T>> predict(Dataset<T> examples) { 139 return predict(examples.getData()); 140 } 141 142 /** 143 * Uses the model to predict the label for multiple examples. 144 * @param examples the examples to predict. 145 * @return the results of the prediction, in the same order as the 146 * examples. 147 */ 148 @Override 149 public List<Prediction<T>> predict(Iterable<Example<T>> examples) { 150 try { 151 DMatrixTuple<T> testMatrix = XGBoostTrainer.convertExamples(examples,featureIDMap); 152 List<float[][]> outputs = new ArrayList<>(); 153 for (Booster model : models) { 154 outputs.add(model.predict(testMatrix.data)); 155 } 156 157 int[] numValidFeatures = testMatrix.numValidFeatures; 158 Example<T>[] exampleArray = testMatrix.examples; 159 return converter.convertBatchOutput(outputIDInfo,outputs,numValidFeatures,exampleArray); 160 } catch (XGBoostError e) { 161 logger.log(Level.SEVERE, "XGBoost threw an error", e); 162 throw new IllegalStateException(e); 163 } 164 165 } 166 167 @Override 168 public Prediction<T> predict(Example<T> example) { 169 try { 170 DMatrixTuple<T> testData = XGBoostTrainer.convertExample(example,featureIDMap); 171 List<float[]> outputs = new ArrayList<>(); 172 for (Booster model : models) { 173 outputs.add(model.predict(testData.data)[0]); 174 } 175 Prediction<T> pred = converter.convertOutput(outputIDInfo,outputs,testData.numValidFeatures[0],example); 176 return pred; 177 } catch (XGBoostError e) { 178 logger.log(Level.SEVERE, "XGBoost threw an error", e); 179 throw new IllegalStateException(e); 180 } 181 } 182 183 @Override 184 public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) { 185 try { 186 int maxFeatures = n < 0 ? featureIDMap.size() : n; 187 // Aggregate feature scores across all the models. 188 // This throws away model specific information which is useful in the case of regression, 189 // but it's very tricky to get the dimension name associated with the model. 190 Map<String, MutableDouble> outputMap = new HashMap<>(); 191 for (Booster model : models) { 192 Map<String, Integer> xgboostMap = model.getFeatureScore(""); 193 for (Map.Entry<String,Integer> f : xgboostMap.entrySet()) { 194 int id = Integer.parseInt(f.getKey().substring(1)); 195 String name = featureIDMap.get(id).getName(); 196 MutableDouble curVal = outputMap.computeIfAbsent(name,(k)->new MutableDouble()); 197 curVal.increment(f.getValue()); 198 } 199 } 200 Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB())); 201 PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures,comparator); 202 for (Map.Entry<String,MutableDouble> e : outputMap.entrySet()) { 203 Pair<String,Double> cur = new Pair<>(e.getKey(), e.getValue().doubleValue()); 204 205 if (q.size() < maxFeatures) { 206 q.offer(cur); 207 } else if (comparator.compare(cur,q.peek()) > 0) { 208 q.poll(); 209 q.offer(cur); 210 } 211 } 212 List<Pair<String,Double>> list = new ArrayList<>(); 213 while(q.size() > 0) { 214 list.add(q.poll()); 215 } 216 Collections.reverse(list); 217 218 Map<String, List<Pair<String,Double>>> map = new HashMap<>(); 219 map.put(Model.ALL_OUTPUTS,list); 220 221 return map; 222 } catch (XGBoostError e) { 223 logger.log(Level.SEVERE, "XGBoost threw an error", e); 224 return Collections.emptyMap(); 225 } 226 } 227 228 /** 229 * Returns the string model dumps from each Booster. 230 * @return The model dumps. 231 */ 232 public List<String[]> getModelDump() { 233 try { 234 List<String[]> list = new ArrayList<>(); 235 for (Booster m : models) { 236 list.add(m.getModelDump("", true)); 237 } 238 return list; 239 } catch (XGBoostError e) { 240 throw new IllegalStateException(e); 241 } 242 } 243 244 @Override 245 public Optional<Excuse<T>> getExcuse(Example<T> example) { 246 return Optional.empty(); 247 } 248 249 /** 250 * Copies a single XGBoost Booster by serializing and deserializing it. 251 * @param booster The booster to copy. 252 * @return A deep copy of the booster. 253 */ 254 static Booster copyModel(Booster booster) { 255 try { 256 byte[] serialisedBooster = booster.toByteArray(); 257 return XGBoost.loadModel(new ByteArrayInputStream(serialisedBooster)); 258 } catch (XGBoostError | IOException e) { 259 throw new IllegalStateException("Unable to copy XGBoost model.",e); 260 } 261 } 262 263 @Override 264 protected Model<T> copy(String newName, ModelProvenance newProvenance) { 265 List<Booster> newModels = new ArrayList<>(); 266 for (Booster model : models) { 267 newModels.add(copyModel(model)); 268 } 269 return new XGBoostModel<>(newName, newProvenance, featureIDMap, outputIDInfo, newModels, converter); 270 } 271 272 private void writeObject(ObjectOutputStream out) throws IOException { 273 out.defaultWriteObject(); 274 try { 275 out.writeInt(models.size()); 276 for (Booster model : models) { 277 byte[] serialisedBooster = model.toByteArray(); 278 out.writeObject(serialisedBooster); 279 } 280 } catch (XGBoostError e) { 281 throw new IOException("Failed to serialize the XGBoost model",e); 282 } 283 } 284 285 private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { 286 in.defaultReadObject(); 287 try { 288 models = new ArrayList<>(); 289 int numModels = in.readInt(); 290 for (int i = 0; i < numModels; i++) { 291 // Now read in the byte array and rebuild each Booster 292 byte[] serialisedBooster = (byte[]) in.readObject(); 293 models.add(XGBoost.loadModel(new ByteArrayInputStream(serialisedBooster))); 294 } 295 } catch (XGBoostError e) { 296 throw new IOException("Failed to deserialize the XGBoost model",e); 297 } 298 } 299}