001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.onnx;
018
019import ai.onnxruntime.OnnxModelMetadata;
020import ai.onnxruntime.OnnxTensor;
021import ai.onnxruntime.OnnxValue;
022import ai.onnxruntime.OrtEnvironment;
023import ai.onnxruntime.OrtException;
024import ai.onnxruntime.OrtSession;
025import com.oracle.labs.mlrg.olcut.provenance.Provenance;
026import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
027import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
028import com.oracle.labs.mlrg.olcut.util.Pair;
029import org.tribuo.Example;
030import org.tribuo.ImmutableFeatureMap;
031import org.tribuo.ImmutableOutputInfo;
032import org.tribuo.Model;
033import org.tribuo.Output;
034import org.tribuo.OutputFactory;
035import org.tribuo.Prediction;
036import org.tribuo.interop.ExternalDatasetProvenance;
037import org.tribuo.interop.ExternalModel;
038import org.tribuo.interop.ExternalTrainerProvenance;
039import org.tribuo.math.la.SparseVector;
040import org.tribuo.provenance.DatasetProvenance;
041import org.tribuo.provenance.ModelProvenance;
042
043import java.io.IOException;
044import java.net.URL;
045import java.nio.file.Files;
046import java.nio.file.Path;
047import java.nio.file.Paths;
048import java.time.OffsetDateTime;
049import java.util.ArrayList;
050import java.util.Arrays;
051import java.util.Collections;
052import java.util.HashMap;
053import java.util.List;
054import java.util.Map;
055import java.util.logging.Level;
056import java.util.logging.Logger;
057
058/**
059 * A Tribuo wrapper around a ONNX model.
060 * <p>
061 * N.B. ONNX support is experimental, and may change without a major version bump.
062 */
063public final class ONNXExternalModel<T extends Output<T>> extends ExternalModel<T, OnnxTensor, List<OnnxValue>> implements AutoCloseable {
064    private static final long serialVersionUID = 1L;
065
066    private static final Logger logger = Logger.getLogger(ONNXExternalModel.class.getName());
067
068    private transient OrtEnvironment env;
069
070    private transient OrtSession.SessionOptions options;
071
072    private transient OrtSession session;
073
074    private final byte[] modelArray;
075
076    private final String inputName;
077
078    private final ExampleTransformer featureTransformer;
079
080    private final OutputTransformer<T> outputTransformer;
081
082    private ONNXExternalModel(String name, ModelProvenance provenance,
083                              ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
084                              Map<String, Integer> featureMapping,
085                              byte[] modelArray, OrtSession.SessionOptions options, String inputName,
086                              ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer) throws OrtException {
087        super(name, provenance, featureIDMap, outputIDInfo, outputTransformer.generatesProbabilities(), featureMapping);
088        this.modelArray = modelArray;
089        this.options = options;
090        this.inputName = inputName;
091        this.featureTransformer = featureTransformer;
092        this.outputTransformer = outputTransformer;
093        this.env = OrtEnvironment.getEnvironment("tribuo-"+name);
094        this.session = env.createSession(modelArray,options);
095    }
096
097    private ONNXExternalModel(String name, ModelProvenance provenance,
098                              ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
099                              int[] featureForwardMapping, int[] featureBackwardMapping,
100                              byte[] modelArray, OrtSession.SessionOptions options, String inputName,
101                              ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer) throws OrtException {
102        super(name,provenance,featureIDMap,outputIDInfo,featureForwardMapping,featureBackwardMapping,
103                outputTransformer.generatesProbabilities());
104        this.modelArray = modelArray;
105        this.options = options;
106        this.inputName = inputName;
107        this.featureTransformer = featureTransformer;
108        this.outputTransformer = outputTransformer;
109        this.env = OrtEnvironment.getEnvironment("tribuo-"+name);
110        this.session = env.createSession(modelArray,options);
111    }
112
113    /**
114     * Closes the session and rebuilds it using the supplied options.
115     * <p>
116     * Used to select a different backend, or change the number of inference threads etc.
117     * @param newOptions The new session options.
118     * @throws OrtException If the model failed to rebuild the session with the supplied options.
119     */
120    public synchronized void rebuild(OrtSession.SessionOptions newOptions) throws OrtException {
121        session.close();
122        if (options != null) {
123            options.close();
124        }
125        options = newOptions;
126        env.createSession(modelArray,newOptions);
127    }
128
129    @Override
130    protected OnnxTensor convertFeatures(SparseVector input) {
131        try {
132            return featureTransformer.transform(env, input);
133        } catch (OrtException e) {
134            throw new IllegalStateException("Failed to construct input OnnxTensor",e);
135        }
136    }
137
138    @Override
139    protected OnnxTensor convertFeaturesList(List<SparseVector> input) {
140        try {
141            return featureTransformer.transform(env, input);
142        } catch (OrtException e) {
143            throw new IllegalStateException("Failed to construct input OnnxTensor",e);
144        }
145    }
146
147    /**
148     * Runs the session to make a prediction.
149     * <p>
150     * Closes the input tensor after the prediction has been made.
151     * @param input The input in the external model's format.
152     * @return A tensor representing the output.
153     */
154    @Override
155    protected List<OnnxValue> externalPrediction(OnnxTensor input) {
156        try {
157            // Note the output of the session is closed by the conversion methods, and should not be closed by the result object.
158            OrtSession.Result output = session.run(Collections.singletonMap(inputName,input));
159            input.close();
160            ArrayList<OnnxValue> outputs = new ArrayList<>();
161            for (Map.Entry<String,OnnxValue> v : output) {
162                outputs.add(v.getValue());
163            }
164            return outputs;
165        } catch (OrtException e) {
166            throw new IllegalStateException("Failed to execute ONNX model",e);
167        }
168    }
169
170    /**
171     * Converts a tensor into a prediction.
172     * Closes the output tensor after it's been converted.
173     * @param output The output of the external model.
174     * @param numValidFeatures The number of valid features in the input.
175     * @param example The input example, used to construct the Prediction.
176     * @return A {@link Prediction} representing this tensor output.
177     */
178    @Override
179    protected Prediction<T> convertOutput(List<OnnxValue> output, int numValidFeatures, Example<T> example) {
180        Prediction<T> pred = outputTransformer.transformToPrediction(output,outputIDInfo,numValidFeatures,example);
181        OnnxValue.close(output);
182        return pred;
183    }
184
185    /**
186     * Converts a tensor into a prediction.
187     * Closes the output tensor after it's been converted.
188     * @param output The output of the external model.
189     * @param numValidFeatures An array with the number of valid features in each example.
190     * @param examples The input examples, used to construct the Predictions.
191     * @return A list of {@link Prediction} representing this tensor output.
192     */
193    @Override
194    protected List<Prediction<T>> convertOutput(List<OnnxValue> output, int[] numValidFeatures, List<Example<T>> examples) {
195        List<Prediction<T>> predictions = outputTransformer.transformToBatchPrediction(output,outputIDInfo,numValidFeatures,examples);
196        OnnxValue.close(output);
197        return predictions;
198    }
199
200    @Override
201    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
202        return Collections.emptyMap();
203    }
204
205    @Override
206    protected synchronized Model<T> copy(String newName, ModelProvenance newProvenance) {
207        byte[] newModelArray = Arrays.copyOf(modelArray,modelArray.length);
208        try {
209            return new ONNXExternalModel<>(newName, newProvenance, featureIDMap, outputIDInfo,
210                    featureForwardMapping, featureBackwardMapping,
211                    newModelArray, options, inputName, featureTransformer, outputTransformer);
212        } catch (OrtException e) {
213            throw new IllegalStateException("Failed to copy ONNX model",e);
214        }
215    }
216
217    @Override
218    public void close() {
219        if (session != null) {
220            try {
221                session.close();
222            } catch (OrtException e) {
223                logger.log(Level.SEVERE,"Exception thrown when closing session",e);
224            }
225        }
226        if (options != null) {
227            options.close();
228        }
229        if (env != null) {
230            try {
231                env.close();
232            } catch (OrtException e) {
233                logger.log(Level.SEVERE,"Exception thrown when closing environment",e);
234            }
235        }
236    }
237
238    /**
239     * Creates an {@code ONNXExternalModel} by loading the model from disk.
240     * @param factory The output factory to use.
241     * @param featureMapping The feature mapping between Tribuo names and ONNX integer ids.
242     * @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids.
243     * @param featureTransformer The transformation function for the features.
244     * @param outputTransformer The transformation function for the outputs.
245     * @param opts The session options for the ONNX model.
246     * @param filename The model path.
247     * @param inputName The name of the input node.
248     * @param <T> The type of the output.
249     * @return An ONNXExternalModel ready to score new inputs.
250     * @throws OrtException If the onnx-runtime native library call failed.
251     */
252    public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory,
253                                                                             Map<String, Integer> featureMapping,
254                                                                             Map<T,Integer> outputMapping,
255                                                                             ExampleTransformer featureTransformer,
256                                                                             OutputTransformer<T> outputTransformer,
257                                                                             OrtSession.SessionOptions opts,
258                                                                             String filename, String inputName) throws OrtException {
259        Path path = Paths.get(filename);
260        return createOnnxModel(factory,featureMapping,outputMapping,featureTransformer,outputTransformer,
261                opts,path,inputName);
262    }
263
264    /**
265     * Creates an {@code ONNXExternalModel} by loading the model from disk.
266     * @param factory The output factory to use.
267     * @param featureMapping The feature mapping between Tribuo names and ONNX integer ids.
268     * @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids.
269     * @param featureTransformer The transformation function for the features.
270     * @param outputTransformer The transformation function for the outputs.
271     * @param opts The session options for the ONNX model.
272     * @param path The model path.
273     * @param inputName The name of the input node.
274     * @param <T> The type of the output.
275     * @return An ONNXExternalModel ready to score new inputs.
276     * @throws OrtException If the onnx-runtime native library call failed.
277     */
278    public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory,
279                                                                             Map<String, Integer> featureMapping,
280                                                                             Map<T,Integer> outputMapping,
281                                                                             ExampleTransformer featureTransformer,
282                                                                             OutputTransformer<T> outputTransformer,
283                                                                             OrtSession.SessionOptions opts,
284                                                                             Path path, String inputName) throws OrtException {
285        try {
286            byte[] modelArray = Files.readAllBytes(path);
287            URL provenanceLocation = path.toUri().toURL();
288            ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
289            ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory,outputMapping);
290            OffsetDateTime now = OffsetDateTime.now();
291            ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
292            DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data",factory,false,featureMapping.size(),outputMapping.size());
293            HashMap<String, Provenance> runProvenance = new HashMap<>();
294            runProvenance.put("input-name", new StringProvenance("input-name", inputName));
295            try (OrtEnvironment env = OrtEnvironment.getEnvironment();
296                 OrtSession session = env.createSession(modelArray)) {
297                OnnxModelMetadata metadata = session.getMetadata();
298                runProvenance.put("model-producer", new StringProvenance("model-producer",metadata.getProducerName()));
299                runProvenance.put("model-domain", new StringProvenance("model-domain",metadata.getDomain()));
300                runProvenance.put("model-description", new StringProvenance("model-description",metadata.getDescription()));
301                runProvenance.put("model-graphname", new StringProvenance("model-graphname",metadata.getGraphName()));
302                runProvenance.put("model-version", new LongProvenance("model-version",metadata.getVersion()));
303                for (Map.Entry<String,String> e : metadata.getCustomMetadata().entrySet()) {
304                    String keyName = "model-metadata-"+e.getKey();
305                    runProvenance.put(keyName, new StringProvenance(keyName,e.getValue()));
306                }
307            } catch (OrtException e) {
308                throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e);
309            }
310            ModelProvenance provenance = new ModelProvenance(ONNXExternalModel.class.getName(),now,datasetProvenance,trainerProvenance,runProvenance);
311            return new ONNXExternalModel<>("external-model",provenance,featureMap,outputInfo,
312                    featureMapping,modelArray,opts,inputName,featureTransformer,outputTransformer);
313        } catch (IOException e) {
314            throw new IllegalArgumentException("Unable to load model from path " + path, e);
315        }
316    }
317
318    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
319        in.defaultReadObject();
320        try {
321            this.env = OrtEnvironment.getEnvironment();
322            this.options = new OrtSession.SessionOptions();
323            this.session = env.createSession(modelArray,options);
324        } catch (OrtException e) {
325            throw new IllegalStateException("Could not construct ONNX Runtime session during deserialization.");
326        }
327    }
328
329}