001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.xgboost; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import ml.dmlc.xgboost4j.java.Booster; 021import ml.dmlc.xgboost4j.java.DMatrix; 022import ml.dmlc.xgboost4j.java.XGBoost; 023import ml.dmlc.xgboost4j.java.XGBoostError; 024import org.tribuo.Example; 025import org.tribuo.ImmutableFeatureMap; 026import org.tribuo.ImmutableOutputInfo; 027import org.tribuo.Model; 028import org.tribuo.Output; 029import org.tribuo.OutputFactory; 030import org.tribuo.Prediction; 031import org.tribuo.interop.ExternalDatasetProvenance; 032import org.tribuo.interop.ExternalModel; 033import org.tribuo.interop.ExternalTrainerProvenance; 034import org.tribuo.math.la.SparseVector; 035import org.tribuo.provenance.DatasetProvenance; 036import org.tribuo.provenance.ModelProvenance; 037 038import java.io.ByteArrayInputStream; 039import java.io.File; 040import java.io.IOException; 041import java.io.ObjectInputStream; 042import java.io.ObjectOutputStream; 043import java.net.MalformedURLException; 044import java.net.URL; 045import java.nio.file.Files; 046import java.nio.file.Path; 047import java.time.OffsetDateTime; 048import java.util.ArrayList; 049import java.util.Collections; 050import java.util.Comparator; 051import java.util.HashMap; 052import java.util.List; 053import java.util.Map; 054import java.util.PriorityQueue; 055import java.util.logging.Level; 056import java.util.logging.Logger; 057 058/** 059 * A {@link Model} which wraps around a XGBoost.Booster which was trained by a system other than Tribuo. 060 * <p> 061 * XGBoost is a fast implementation of gradient boosted decision trees. 062 * <p> 063 * Throws IllegalStateException if the XGBoost C++ library fails to load or throws an exception. 064 * <p> 065 * See: 066 * <pre> 067 * Chen T, Guestrin C. 068 * "XGBoost: A Scalable Tree Boosting System" 069 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016. 070 * </pre> 071 * <p> 072 * and for the original algorithm: 073 * <pre> 074 * Friedman JH. 075 * "Greedy Function Approximation: a Gradient Boosting Machine" 076 * Annals of statistics, 2001. 077 * </pre> 078 * <p> 079 * Note: XGBoost requires a native library, on macOS this library requires libomp (which can be installed via homebrew), 080 * on Windows this native library must be compiled into a jar as it's not contained in the official XGBoost binary 081 * on Maven Central. 082 */ 083public final class XGBoostExternalModel<T extends Output<T>> extends ExternalModel<T,DMatrix,float[][]> { 084 private static final long serialVersionUID = 1L; 085 086 private static final Logger logger = Logger.getLogger(XGBoostExternalModel.class.getName()); 087 088 private final XGBoostOutputConverter<T> converter; 089 090 /** 091 * Transient as we rely upon the native serialisation mechanism to bytes rather than Java serializing the Booster. 092 */ 093 protected transient Booster model; 094 095 private XGBoostExternalModel(String name, ModelProvenance provenance, 096 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, 097 Map<String, Integer> featureMapping, Booster model, 098 XGBoostOutputConverter<T> converter) { 099 super(name, provenance, featureIDMap, outputIDInfo, converter.generatesProbabilities(), featureMapping); 100 this.model = model; 101 this.converter = converter; 102 } 103 104 private XGBoostExternalModel(String name, ModelProvenance provenance, 105 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, 106 int[] featureForwardMapping, int[] featureBackwardMapping, 107 Booster model, XGBoostOutputConverter<T> converter) { 108 super(name,provenance,featureIDMap,outputIDInfo,featureForwardMapping,featureBackwardMapping, 109 converter.generatesProbabilities()); 110 this.model = model; 111 this.converter = converter; 112 } 113 114 @Override 115 protected DMatrix convertFeatures(SparseVector input) { 116 try { 117 return XGBoostTrainer.convertSparseVector(input); 118 } catch (XGBoostError e) { 119 logger.severe("XGBoost threw an error while constructing the DMatrix."); 120 throw new IllegalStateException(e); 121 } 122 } 123 124 @Override 125 protected DMatrix convertFeaturesList(List<SparseVector> input) { 126 try { 127 return XGBoostTrainer.convertSparseVectors(input); 128 } catch (XGBoostError e) { 129 logger.severe("XGBoost threw an error while constructing the DMatrix."); 130 throw new IllegalStateException(e); 131 } 132 } 133 134 @Override 135 protected float[][] externalPrediction(DMatrix input) { 136 try { 137 return model.predict(input); 138 } catch (XGBoostError e) { 139 logger.severe("XGBoost threw an error while predicting."); 140 throw new IllegalStateException(e); 141 } 142 } 143 144 @Override 145 protected Prediction<T> convertOutput(float[][] output, int numValidFeatures, Example<T> example) { 146 return converter.convertOutput(outputIDInfo,Collections.singletonList(output[0]),numValidFeatures,example); 147 } 148 149 @SuppressWarnings("unchecked") // generic array creation 150 @Override 151 protected List<Prediction<T>> convertOutput(float[][] output, int[] numValidFeatures, List<Example<T>> examples) { 152 return converter.convertBatchOutput(outputIDInfo,Collections.singletonList(output),numValidFeatures,(Example<T>[])examples.toArray(new Example[0])); 153 } 154 155 @Override 156 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 157 try { 158 int maxFeatures = n < 0 ? featureIDMap.size() : n; 159 Map<String, Integer> xgboostMap = model.getFeatureScore(""); 160 Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB())); 161 PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures,comparator); 162 //iterate over the scored features 163 for (Map.Entry<String, Integer> f : xgboostMap.entrySet()) { 164 int id = Integer.parseInt(f.getKey().substring(1)); 165 Pair<String,Double> cur = new Pair<>(featureIDMap.get(featureBackwardMapping[id]).getName(), (double) f.getValue()); 166 167 if (q.size() < maxFeatures) { 168 q.offer(cur); 169 } else if (comparator.compare(cur,q.peek()) > 0) { 170 q.poll(); 171 q.offer(cur); 172 } 173 } 174 List<Pair<String,Double>> list = new ArrayList<>(); 175 while(q.size() > 0) { 176 list.add(q.poll()); 177 } 178 Collections.reverse(list); 179 180 Map<String, List<Pair<String,Double>>> map = new HashMap<>(); 181 map.put(Model.ALL_OUTPUTS,list); 182 183 return map; 184 } catch (XGBoostError e) { 185 logger.log(Level.SEVERE, "XGBoost threw an error", e); 186 return Collections.emptyMap(); 187 } 188 } 189 190 @Override 191 protected XGBoostExternalModel<T> copy(String newName, ModelProvenance newProvenance) { 192 return new XGBoostExternalModel<>(newName, newProvenance, featureIDMap, outputIDInfo, 193 featureForwardMapping, featureBackwardMapping, 194 XGBoostModel.copyModel(model), converter); 195 } 196 197 /** 198 * Creates an {@code XGBoostExternalModel} from the supplied model on disk. 199 * @param factory The output factory to use. 200 * @param featureMapping The feature mapping between Tribuo names and XGBoost integer ids. 201 * @param outputMapping The output mapping between Tribuo outputs and XGBoost integer ids. 202 * @param outputFunc The XGBoostOutputConverter function for the output type. 203 * @param path The path to the model on disk. 204 * @param <T> The type of the output. 205 * @return An XGBoostExternalModel ready to score new inputs. 206 */ 207 public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T,Integer> outputMapping, XGBoostOutputConverter<T> outputFunc, String path) { 208 try { 209 Booster model = XGBoost.loadModel(path); 210 return createXGBoostModel(factory,featureMapping,outputMapping,outputFunc,model,new File(path).toURI().toURL()); 211 } catch (XGBoostError | MalformedURLException e) { 212 throw new IllegalArgumentException("Unable to load model from path " + path, e); 213 } 214 } 215 216 /** 217 * Creates an {@code XGBoostExternalModel} from the supplied model on disk. 218 * @param factory The output factory to use. 219 * @param featureMapping The feature mapping between Tribuo names and XGBoost integer ids. 220 * @param outputMapping The output mapping between Tribuo outputs and XGBoost integer ids. 221 * @param outputFunc The XGBoostOutputConverter function for the output type. 222 * @param path The path to the model on disk. 223 * @param <T> The type of the output. 224 * @return An XGBoostExternalModel ready to score new inputs. 225 */ 226 public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T,Integer> outputMapping, XGBoostOutputConverter<T> outputFunc, Path path) { 227 try { 228 Booster model = XGBoost.loadModel(Files.newInputStream(path)); 229 return createXGBoostModel(factory,featureMapping,outputMapping,outputFunc,model,path.toUri().toURL()); 230 } catch (XGBoostError | IOException e) { 231 throw new IllegalArgumentException("Unable to load model from path " + path, e); 232 } 233 } 234 235 /** 236 * Creates an {@code XGBoostExternalModel} from the supplied model. 237 * <p> 238 * Note: the provenance system requires that the URL point to a valid local file and 239 * will throw an exception if it is not. However it doesn't check that the file is 240 * where the Booster was created from. 241 * We will replace this entry point with one that accepts useful provenance information 242 * for an in-memory {@code Booster} object in a future release, and deprecate this 243 * endpoint at that time. 244 * @param factory The output factory to use. 245 * @param featureMapping The feature mapping between Tribuo names and XGBoost integer ids. 246 * @param outputMapping The output mapping between Tribuo outputs and XGBoost integer ids. 247 * @param outputFunc The XGBoostOutputConverter function for the output type. 248 * @param model The XGBoost model to wrap. 249 * @param provenanceLocation The location where the model was loaded from. 250 * @param <T> The type of the output. 251 * @return An XGBoostExternalModel ready to score new inputs. 252 */ 253 public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> factory, Map<String,Integer> featureMapping, Map<T,Integer> outputMapping, XGBoostOutputConverter<T> outputFunc, Booster model, URL provenanceLocation) { 254 //TODO: add a new version of this method which accepts useful instance provenance information and deprecate this one 255 ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet()); 256 ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory,outputMapping); 257 OffsetDateTime now = OffsetDateTime.now(); 258 ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation); 259 DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data",factory,false,featureMapping.size(),outputMapping.size()); 260 ModelProvenance provenance = new ModelProvenance(XGBoostExternalModel.class.getName(),now,datasetProvenance,trainerProvenance); 261 return new XGBoostExternalModel<>("external-model",provenance,featureMap,outputInfo, 262 featureMapping,model,outputFunc); 263 } 264 265 private void writeObject(ObjectOutputStream out) throws IOException { 266 out.defaultWriteObject(); 267 try { 268 byte[] serialisedBooster = model.toByteArray(); 269 out.writeObject(serialisedBooster); 270 } catch (XGBoostError e) { 271 throw new IOException("Failed to serialize the XGBoost model",e); 272 } 273 } 274 275 private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { 276 in.defaultReadObject(); 277 try { 278 // Now read in the byte array and rebuild the Booster 279 byte[] serialisedBooster = (byte[]) in.readObject(); 280 model = XGBoost.loadModel(new ByteArrayInputStream(serialisedBooster)); 281 } catch (XGBoostError e) { 282 throw new IOException("Failed to deserialize the XGBoost model",e); 283 } 284 } 285}