001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.onnx; 018 019import ai.onnxruntime.OnnxModelMetadata; 020import ai.onnxruntime.OnnxTensor; 021import ai.onnxruntime.OnnxValue; 022import ai.onnxruntime.OrtEnvironment; 023import ai.onnxruntime.OrtException; 024import ai.onnxruntime.OrtSession; 025import com.oracle.labs.mlrg.olcut.provenance.Provenance; 026import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance; 027import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 028import com.oracle.labs.mlrg.olcut.util.Pair; 029import org.tribuo.Example; 030import org.tribuo.ImmutableFeatureMap; 031import org.tribuo.ImmutableOutputInfo; 032import org.tribuo.Model; 033import org.tribuo.Output; 034import org.tribuo.OutputFactory; 035import org.tribuo.Prediction; 036import org.tribuo.interop.ExternalDatasetProvenance; 037import org.tribuo.interop.ExternalModel; 038import org.tribuo.interop.ExternalTrainerProvenance; 039import org.tribuo.math.la.SparseVector; 040import org.tribuo.provenance.DatasetProvenance; 041import org.tribuo.provenance.ModelProvenance; 042 043import java.io.IOException; 044import java.net.URL; 045import java.nio.file.Files; 046import java.nio.file.Path; 047import java.nio.file.Paths; 048import java.time.OffsetDateTime; 049import java.util.ArrayList; 050import java.util.Arrays; 051import java.util.Collections; 052import java.util.HashMap; 053import java.util.List; 054import java.util.Map; 055import java.util.logging.Level; 056import java.util.logging.Logger; 057 058/** 059 * A Tribuo wrapper around a ONNX model. 060 * <p> 061 * N.B. ONNX support is experimental, and may change without a major version bump. 062 */ 063public final class ONNXExternalModel<T extends Output<T>> extends ExternalModel<T, OnnxTensor, List<OnnxValue>> implements AutoCloseable { 064 private static final long serialVersionUID = 1L; 065 066 private static final Logger logger = Logger.getLogger(ONNXExternalModel.class.getName()); 067 068 private transient OrtEnvironment env; 069 070 private transient OrtSession.SessionOptions options; 071 072 private transient OrtSession session; 073 074 private final byte[] modelArray; 075 076 private final String inputName; 077 078 private final ExampleTransformer featureTransformer; 079 080 private final OutputTransformer<T> outputTransformer; 081 082 private ONNXExternalModel(String name, ModelProvenance provenance, 083 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, 084 Map<String, Integer> featureMapping, 085 byte[] modelArray, OrtSession.SessionOptions options, String inputName, 086 ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer) throws OrtException { 087 super(name, provenance, featureIDMap, outputIDInfo, outputTransformer.generatesProbabilities(), featureMapping); 088 this.modelArray = modelArray; 089 this.options = options; 090 this.inputName = inputName; 091 this.featureTransformer = featureTransformer; 092 this.outputTransformer = outputTransformer; 093 this.env = OrtEnvironment.getEnvironment("tribuo-"+name); 094 this.session = env.createSession(modelArray,options); 095 } 096 097 private ONNXExternalModel(String name, ModelProvenance provenance, 098 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, 099 int[] featureForwardMapping, int[] featureBackwardMapping, 100 byte[] modelArray, OrtSession.SessionOptions options, String inputName, 101 ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer) throws OrtException { 102 super(name,provenance,featureIDMap,outputIDInfo,featureForwardMapping,featureBackwardMapping, 103 outputTransformer.generatesProbabilities()); 104 this.modelArray = modelArray; 105 this.options = options; 106 this.inputName = inputName; 107 this.featureTransformer = featureTransformer; 108 this.outputTransformer = outputTransformer; 109 this.env = OrtEnvironment.getEnvironment("tribuo-"+name); 110 this.session = env.createSession(modelArray,options); 111 } 112 113 /** 114 * Closes the session and rebuilds it using the supplied options. 115 * <p> 116 * Used to select a different backend, or change the number of inference threads etc. 117 * @param newOptions The new session options. 118 * @throws OrtException If the model failed to rebuild the session with the supplied options. 119 */ 120 public synchronized void rebuild(OrtSession.SessionOptions newOptions) throws OrtException { 121 session.close(); 122 if (options != null) { 123 options.close(); 124 } 125 options = newOptions; 126 env.createSession(modelArray,newOptions); 127 } 128 129 @Override 130 protected OnnxTensor convertFeatures(SparseVector input) { 131 try { 132 return featureTransformer.transform(env, input); 133 } catch (OrtException e) { 134 throw new IllegalStateException("Failed to construct input OnnxTensor",e); 135 } 136 } 137 138 @Override 139 protected OnnxTensor convertFeaturesList(List<SparseVector> input) { 140 try { 141 return featureTransformer.transform(env, input); 142 } catch (OrtException e) { 143 throw new IllegalStateException("Failed to construct input OnnxTensor",e); 144 } 145 } 146 147 /** 148 * Runs the session to make a prediction. 149 * <p> 150 * Closes the input tensor after the prediction has been made. 151 * @param input The input in the external model's format. 152 * @return A tensor representing the output. 153 */ 154 @Override 155 protected List<OnnxValue> externalPrediction(OnnxTensor input) { 156 try { 157 // Note the output of the session is closed by the conversion methods, and should not be closed by the result object. 158 OrtSession.Result output = session.run(Collections.singletonMap(inputName,input)); 159 input.close(); 160 ArrayList<OnnxValue> outputs = new ArrayList<>(); 161 for (Map.Entry<String,OnnxValue> v : output) { 162 outputs.add(v.getValue()); 163 } 164 return outputs; 165 } catch (OrtException e) { 166 throw new IllegalStateException("Failed to execute ONNX model",e); 167 } 168 } 169 170 /** 171 * Converts a tensor into a prediction. 172 * Closes the output tensor after it's been converted. 173 * @param output The output of the external model. 174 * @param numValidFeatures The number of valid features in the input. 175 * @param example The input example, used to construct the Prediction. 176 * @return A {@link Prediction} representing this tensor output. 177 */ 178 @Override 179 protected Prediction<T> convertOutput(List<OnnxValue> output, int numValidFeatures, Example<T> example) { 180 Prediction<T> pred = outputTransformer.transformToPrediction(output,outputIDInfo,numValidFeatures,example); 181 OnnxValue.close(output); 182 return pred; 183 } 184 185 /** 186 * Converts a tensor into a prediction. 187 * Closes the output tensor after it's been converted. 188 * @param output The output of the external model. 189 * @param numValidFeatures An array with the number of valid features in each example. 190 * @param examples The input examples, used to construct the Predictions. 191 * @return A list of {@link Prediction} representing this tensor output. 192 */ 193 @Override 194 protected List<Prediction<T>> convertOutput(List<OnnxValue> output, int[] numValidFeatures, List<Example<T>> examples) { 195 List<Prediction<T>> predictions = outputTransformer.transformToBatchPrediction(output,outputIDInfo,numValidFeatures,examples); 196 OnnxValue.close(output); 197 return predictions; 198 } 199 200 @Override 201 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 202 return Collections.emptyMap(); 203 } 204 205 @Override 206 protected synchronized Model<T> copy(String newName, ModelProvenance newProvenance) { 207 byte[] newModelArray = Arrays.copyOf(modelArray,modelArray.length); 208 try { 209 return new ONNXExternalModel<>(newName, newProvenance, featureIDMap, outputIDInfo, 210 featureForwardMapping, featureBackwardMapping, 211 newModelArray, options, inputName, featureTransformer, outputTransformer); 212 } catch (OrtException e) { 213 throw new IllegalStateException("Failed to copy ONNX model",e); 214 } 215 } 216 217 @Override 218 public void close() { 219 if (session != null) { 220 try { 221 session.close(); 222 } catch (OrtException e) { 223 logger.log(Level.SEVERE,"Exception thrown when closing session",e); 224 } 225 } 226 if (options != null) { 227 options.close(); 228 } 229 if (env != null) { 230 try { 231 env.close(); 232 } catch (OrtException e) { 233 logger.log(Level.SEVERE,"Exception thrown when closing environment",e); 234 } 235 } 236 } 237 238 /** 239 * Creates an {@code ONNXExternalModel} by loading the model from disk. 240 * @param factory The output factory to use. 241 * @param featureMapping The feature mapping between Tribuo names and ONNX integer ids. 242 * @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids. 243 * @param featureTransformer The transformation function for the features. 244 * @param outputTransformer The transformation function for the outputs. 245 * @param opts The session options for the ONNX model. 246 * @param filename The model path. 247 * @param inputName The name of the input node. 248 * @param <T> The type of the output. 249 * @return An ONNXExternalModel ready to score new inputs. 250 * @throws OrtException If the onnx-runtime native library call failed. 251 */ 252 public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, 253 Map<String, Integer> featureMapping, 254 Map<T,Integer> outputMapping, 255 ExampleTransformer featureTransformer, 256 OutputTransformer<T> outputTransformer, 257 OrtSession.SessionOptions opts, 258 String filename, String inputName) throws OrtException { 259 Path path = Paths.get(filename); 260 return createOnnxModel(factory,featureMapping,outputMapping,featureTransformer,outputTransformer, 261 opts,path,inputName); 262 } 263 264 /** 265 * Creates an {@code ONNXExternalModel} by loading the model from disk. 266 * @param factory The output factory to use. 267 * @param featureMapping The feature mapping between Tribuo names and ONNX integer ids. 268 * @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids. 269 * @param featureTransformer The transformation function for the features. 270 * @param outputTransformer The transformation function for the outputs. 271 * @param opts The session options for the ONNX model. 272 * @param path The model path. 273 * @param inputName The name of the input node. 274 * @param <T> The type of the output. 275 * @return An ONNXExternalModel ready to score new inputs. 276 * @throws OrtException If the onnx-runtime native library call failed. 277 */ 278 public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, 279 Map<String, Integer> featureMapping, 280 Map<T,Integer> outputMapping, 281 ExampleTransformer featureTransformer, 282 OutputTransformer<T> outputTransformer, 283 OrtSession.SessionOptions opts, 284 Path path, String inputName) throws OrtException { 285 try { 286 byte[] modelArray = Files.readAllBytes(path); 287 URL provenanceLocation = path.toUri().toURL(); 288 ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet()); 289 ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory,outputMapping); 290 OffsetDateTime now = OffsetDateTime.now(); 291 ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation); 292 DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data",factory,false,featureMapping.size(),outputMapping.size()); 293 HashMap<String, Provenance> runProvenance = new HashMap<>(); 294 runProvenance.put("input-name", new StringProvenance("input-name", inputName)); 295 try (OrtEnvironment env = OrtEnvironment.getEnvironment(); 296 OrtSession session = env.createSession(modelArray)) { 297 OnnxModelMetadata metadata = session.getMetadata(); 298 runProvenance.put("model-producer", new StringProvenance("model-producer",metadata.getProducerName())); 299 runProvenance.put("model-domain", new StringProvenance("model-domain",metadata.getDomain())); 300 runProvenance.put("model-description", new StringProvenance("model-description",metadata.getDescription())); 301 runProvenance.put("model-graphname", new StringProvenance("model-graphname",metadata.getGraphName())); 302 runProvenance.put("model-version", new LongProvenance("model-version",metadata.getVersion())); 303 for (Map.Entry<String,String> e : metadata.getCustomMetadata().entrySet()) { 304 String keyName = "model-metadata-"+e.getKey(); 305 runProvenance.put(keyName, new StringProvenance(keyName,e.getValue())); 306 } 307 } catch (OrtException e) { 308 throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e); 309 } 310 ModelProvenance provenance = new ModelProvenance(ONNXExternalModel.class.getName(),now,datasetProvenance,trainerProvenance,runProvenance); 311 return new ONNXExternalModel<>("external-model",provenance,featureMap,outputInfo, 312 featureMapping,modelArray,opts,inputName,featureTransformer,outputTransformer); 313 } catch (IOException e) { 314 throw new IllegalArgumentException("Unable to load model from path " + path, e); 315 } 316 } 317 318 private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { 319 in.defaultReadObject(); 320 try { 321 this.env = OrtEnvironment.getEnvironment(); 322 this.options = new OrtSession.SessionOptions(); 323 this.session = env.createSession(modelArray,options); 324 } catch (OrtException e) { 325 throw new IllegalStateException("Could not construct ONNX Runtime session during deserialization."); 326 } 327 } 328 329}