/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package org.tribuo.interop.onnx;
018
019import ai.onnxruntime.OnnxJavaType;
020import ai.onnxruntime.OnnxTensor;
021import ai.onnxruntime.OnnxValue;
022import ai.onnxruntime.OrtException;
023import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
025import com.oracle.labs.mlrg.olcut.util.Pair;
026import org.tribuo.Example;
027import org.tribuo.ImmutableOutputInfo;
028import org.tribuo.Prediction;
029import org.tribuo.regression.Regressor;
030
031import java.util.ArrayList;
032import java.util.Arrays;
033import java.util.List;
034
035/**
036 * Can convert an {@link OnnxValue} into a {@link Prediction} or {@link Regressor}.
037 */
038public class RegressorTransformer implements OutputTransformer<Regressor> {
039    private static final long serialVersionUID = 1L;
040
041    @Override
042    public Prediction<Regressor> transformToPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int numValidFeatures, Example<Regressor> example) {
043        Regressor r = transformToOutput(tensor,outputIDInfo);
044        return new Prediction<>(r,numValidFeatures,example);
045    }
046
047    @Override
048    public Regressor transformToOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
049        float[][] predictions = getBatchPredictions(tensor);
050        if (predictions.length != 1) {
051            throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length);
052        } else if (predictions[0].length != outputIDInfo.size()) {
053            throw new IllegalArgumentException("Supplied tensor has an incorrect number of dimensions, predictions[0].length = " + predictions[0].length + ", expected " + outputIDInfo.size());
054        }
055        String[] names = new String[outputIDInfo.size()];
056        double[] values = new double[outputIDInfo.size()];
057        for (Pair<Integer,Regressor> p : outputIDInfo) {
058            int id = p.getA();
059            names[id] = p.getB().getNames()[0];
060            values[id] = predictions[0][id];
061        }
062        return new Regressor(names,values);
063    }
064
065    private float[][] getBatchPredictions(List<OnnxValue> valueList) {
066        if (valueList.size() != 1) {
067            throw new IllegalArgumentException("Supplied output has incorrect number of elements, expected 1, found " + valueList.size());
068        }
069
070        OnnxValue value = valueList.get(0);
071        if (value instanceof OnnxTensor) {
072            OnnxTensor tensor = (OnnxTensor) value;
073            long[] shape = tensor.getInfo().getShape();
074            if (shape.length != 2) {
075                throw new IllegalArgumentException("Expected shape [batchSize][numDimensions], found " + Arrays.toString(shape));
076            } else {
077                try {
078                    if (tensor.getInfo().type == OnnxJavaType.FLOAT) {
079                        // Will return a float array
080                        return (float[][]) tensor.getValue();
081                    } else {
082                        throw new IllegalArgumentException("Supplied output was an invalid tensor type, expected float, found " + tensor.getInfo().type);
083                    }
084                } catch (OrtException e) {
085                    throw new IllegalStateException("Failed to read tensor value",e);
086                }
087            }
088        } else {
089            throw new IllegalArgumentException("Supplied output was not an OnnxTensor, found " + value.getClass().toString());
090        }
091    }
092
093    @Override
094    public List<Prediction<Regressor>> transformToBatchPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int[] numValidFeatures, List<Example<Regressor>> examples) {
095        List<Regressor> regressors = transformToBatchOutput(tensor,outputIDInfo);
096        List<Prediction<Regressor>> output = new ArrayList<>();
097
098        if ((regressors.size() != examples.size()) || (regressors.size() != numValidFeatures.length)) {
099            throw new IllegalArgumentException("Invalid number of predictions received from the ONNXExternalModel, expected " + numValidFeatures.length + ", received " + regressors.size());
100        }
101
102        for (int i = 0; i < regressors.size(); i++) {
103            output.add(new Prediction<>(regressors.get(i),numValidFeatures[i],examples.get(i)));
104        }
105
106        return output;
107    }
108
109    @Override
110    public List<Regressor> transformToBatchOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
111        float[][] predictions = getBatchPredictions(tensor);
112        List<Regressor> output = new ArrayList<>();
113
114        String[] names = new String[outputIDInfo.size()];
115        for (Pair<Integer,Regressor> p : outputIDInfo) {
116            int id = p.getA();
117            names[id] = p.getB().getNames()[0];
118        }
119        for (int i = 0; i < predictions.length; i++) {
120            double[] values = new double[names.length];
121            for (int j = 0; j < names.length; j++) {
122                values[j] = predictions[i][j];
123            }
124            output.add(new Regressor(names,values));
125        }
126
127        return output;
128    }
129
130    @Override
131    public boolean generatesProbabilities() {
132        return false;
133    }
134
135    @Override
136    public String toString() {
137        return "RegressorTransformer()";
138    }
139
140    @Override
141    public ConfiguredObjectProvenance getProvenance() {
142        return new ConfiguredObjectProvenanceImpl(this,"OutputTransformer");
143    }
144}