001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.onnx;
018
019import ai.onnxruntime.OnnxJavaType;
020import ai.onnxruntime.OnnxSequence;
021import ai.onnxruntime.OnnxTensor;
022import ai.onnxruntime.OnnxValue;
023import ai.onnxruntime.OrtException;
024import ai.onnxruntime.SequenceInfo;
025import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
026import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
027import org.tribuo.Example;
028import org.tribuo.ImmutableOutputInfo;
029import org.tribuo.Prediction;
030import org.tribuo.classification.Label;
031
032import java.util.ArrayList;
033import java.util.Arrays;
034import java.util.HashMap;
035import java.util.List;
036import java.util.Map;
037import java.util.logging.Logger;
038
039/**
040 * Can convert an {@link OnnxValue} into a {@link Prediction} or a {@link Label}.
041 * <p>
042 *     Accepts both a tuple (tensor,sequence(map(long,float))) and a single tensor.
043 *     The former is usually produced by scikit-learn or similar, the latter is produced by pytorch.
044 * </p>
045 */
046public class LabelTransformer implements OutputTransformer<Label> {
047    private static final long serialVersionUID = 1L;
048    private static final Logger logger = Logger.getLogger(LabelTransformer.class.getName());
049
050    @Override
051    public Prediction<Label> transformToPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo, int numValidFeatures, Example<Label> example) {
052        float[][] predictions = getBatchPredictions(tensor,outputIDInfo);
053        if (predictions.length != 1) {
054            throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length);
055        }
056        return generatePrediction(predictions[0],outputIDInfo,numValidFeatures,example);
057    }
058
059    private Prediction<Label> generatePrediction(float[] predictions, ImmutableOutputInfo<Label> outputIDInfo, int numUsed, Example<Label> example) {
060        Label max = null;
061        Map<String,Label> map = new HashMap<>();
062        for (int i = 0; i < predictions.length; i++) {
063            Label current = new Label(outputIDInfo.getOutput(i).getLabel(),predictions[i]);
064            map.put(current.getLabel(),current);
065            if ((max == null) || (current.getScore() > max.getScore())) {
066                max = current;
067            }
068        }
069        return new Prediction<>(max,map,numUsed,example,true);
070    }
071
072    @Override
073    public Label transformToOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo) {
074        float[][] predictions = getBatchPredictions(tensor,outputIDInfo);
075        if (predictions.length != 1) {
076            throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length);
077        }
078        return generateLabel(predictions[0],outputIDInfo);
079    }
080
081    private Label generateLabel(float[] predictions, ImmutableOutputInfo<Label> outputIDInfo) {
082        int maxIdx = 0;
083        float max = Float.NEGATIVE_INFINITY;
084        for (int i = 0; i < predictions.length; i++) {
085            if (predictions[i] > max) {
086                maxIdx = i;
087                max = predictions[i];
088            }
089        }
090        return new Label(outputIDInfo.getOutput(maxIdx).getLabel(),max);
091    }
092
093    /**
094     * Rationalises the output of an onnx model into a standard format suitable for
095     * downstream work in Tribuo.
096     * @param inputs The onnx model output.
097     * @param outputIDInfo The output id mapping.
098     * @return A 2d array of outputs, the first dimension is batch size, the second dimension is the output space.
099     */
100    private float[][] getBatchPredictions(List<OnnxValue> inputs, ImmutableOutputInfo<Label> outputIDInfo) {
101        try {
102            if (inputs.size() == 1) {
103                // Single OnnxTensor [batchSize][numOutputDims]
104                if (inputs.get(0) instanceof OnnxTensor) {
105                    OnnxTensor output = (OnnxTensor) inputs.get(0);
106                    if (output.getInfo().type == OnnxJavaType.FLOAT) {
107                        long[] shape = output.getInfo().getShape();
108                        if ((shape.length == 2) && (shape[1] == outputIDInfo.size())) {
109                            return (float[][]) output.getValue();
110                        } else {
111                            throw new IllegalArgumentException("Invalid shape for the probabilities tensor, expected shape [batchSize,numOutputs], found " + Arrays.toString(shape));
112                        }
113                    } else {
114                        throw new IllegalArgumentException("Expected the first element to be a float OnnxTensor, found " + inputs.get(0));
115                    }
116                } else {
117                    throw new IllegalArgumentException("Expected the first element to be a float OnnxTensor, found " + inputs.get(0));
118                }
119            } else if (inputs.size() == 2) {
120                // First element is OnnxTensor [batchSize] containing the int predicted label ids, second element is a OnnxSequence<ONNXMap<long,float>>
121                if (inputs.get(1) instanceof OnnxSequence) {
122                    OnnxSequence seq = (OnnxSequence) inputs.get(1);
123                    SequenceInfo info = seq.getInfo();
124                    if ((info.sequenceOfMaps) && (info.mapInfo.keyType == OnnxJavaType.INT64) && (info.mapInfo.valueType == OnnxJavaType.FLOAT)) {
125                        List<?> output = seq.getValue();
126                        float[][] outputArray = new float[output.size()][outputIDInfo.size()];
127                        int i = 0;
128                        for (Object o : output) {
129                            @SuppressWarnings("unchecked") // guarded by the if on the mapInfo above.
130                            Map<Long,Float> map = (Map<Long,Float>) o;
131                            if (map.size() == outputIDInfo.size()) {
132                                for (Map.Entry<Long,Float> e : map.entrySet()) {
133                                    Long key = e.getKey();
134                                    if (key != (int)(long) key) {
135                                        throw new IllegalArgumentException("Key not representable as a Java int, this model is not supported. Expected value less than 2^32, received " + key);
136                                    }
137                                    outputArray[i][(int)(long)key] = e.getValue();
138                                }
139                            } else {
140                                throw new IllegalArgumentException("Expected " + outputIDInfo.size() + " entries in the " + i + "th element, found " + map.size());
141                            }
142                            i++;
143                        }
144                        return outputArray;
145                    } else {
146                        throw new IllegalArgumentException("Expected a List<Map<Long,Float>>, received a " + info.toString());
147                    }
148                } else {
149                    throw new IllegalArgumentException("Expected a List<Map<Long,Float>>, received a " + inputs.get(1).getInfo().toString());
150                }
151            } else {
152                throw new IllegalArgumentException("Unexpected number of OnnxValues returned, expected 1 or 2, received " + inputs.size());
153            }
154        } catch (OrtException e) {
155            throw new IllegalStateException("Failed to read a value out of the onnx result.",e);
156        }
157    }
158
159    @Override
160    public List<Prediction<Label>> transformToBatchPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo, int[] numValidFeatures, List<Example<Label>> examples) {
161        float[][] predictions = getBatchPredictions(tensor,outputIDInfo);
162        List<Prediction<Label>> output = new ArrayList<>();
163
164        if ((predictions.length != examples.size()) || (predictions.length != numValidFeatures.length)) {
165            throw new IllegalArgumentException("Invalid number of predictions received from the ONNXExternalModel, expected " + numValidFeatures.length + ", received " + predictions.length);
166        }
167
168        for (int i = 0; i < predictions.length; i++) {
169            output.add(generatePrediction(predictions[i],outputIDInfo,numValidFeatures[i],examples.get(i)));
170        }
171
172        return output;
173    }
174
175    @Override
176    public List<Label> transformToBatchOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo) {
177        float[][] predictions = getBatchPredictions(tensor,outputIDInfo);
178        List<Label> output = new ArrayList<>();
179
180        for (int i = 0; i < predictions.length; i++) {
181            output.add(generateLabel(predictions[i],outputIDInfo));
182        }
183
184        return output;
185    }
186
187    @Override
188    public boolean generatesProbabilities() {
189        return true;
190    }
191
192    @Override
193    public String toString() {
194        return "LabelTransformer()";
195    }
196
197    @Override
198    public ConfiguredObjectProvenance getProvenance() {
199        return new ConfiguredObjectProvenanceImpl(this,"OutputTransformer");
200    }
201}