001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.onnx; 018 019import ai.onnxruntime.OnnxJavaType; 020import ai.onnxruntime.OnnxSequence; 021import ai.onnxruntime.OnnxTensor; 022import ai.onnxruntime.OnnxValue; 023import ai.onnxruntime.OrtException; 024import ai.onnxruntime.SequenceInfo; 025import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 026import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 027import org.tribuo.Example; 028import org.tribuo.ImmutableOutputInfo; 029import org.tribuo.Prediction; 030import org.tribuo.classification.Label; 031 032import java.util.ArrayList; 033import java.util.Arrays; 034import java.util.HashMap; 035import java.util.List; 036import java.util.Map; 037import java.util.logging.Logger; 038 039/** 040 * Can convert an {@link OnnxValue} into a {@link Prediction} or a {@link Label}. 041 * <p> 042 * Accepts both a tuple (tensor,sequence(map(long,float))) and a single tensor. 043 * The former is usually produced by scikit-learn or similar, the latter is produced by pytorch. 
044 * </p> 045 */ 046public class LabelTransformer implements OutputTransformer<Label> { 047 private static final long serialVersionUID = 1L; 048 private static final Logger logger = Logger.getLogger(LabelTransformer.class.getName()); 049 050 @Override 051 public Prediction<Label> transformToPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo, int numValidFeatures, Example<Label> example) { 052 float[][] predictions = getBatchPredictions(tensor,outputIDInfo); 053 if (predictions.length != 1) { 054 throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length); 055 } 056 return generatePrediction(predictions[0],outputIDInfo,numValidFeatures,example); 057 } 058 059 private Prediction<Label> generatePrediction(float[] predictions, ImmutableOutputInfo<Label> outputIDInfo, int numUsed, Example<Label> example) { 060 Label max = null; 061 Map<String,Label> map = new HashMap<>(); 062 for (int i = 0; i < predictions.length; i++) { 063 Label current = new Label(outputIDInfo.getOutput(i).getLabel(),predictions[i]); 064 map.put(current.getLabel(),current); 065 if ((max == null) || (current.getScore() > max.getScore())) { 066 max = current; 067 } 068 } 069 return new Prediction<>(max,map,numUsed,example,true); 070 } 071 072 @Override 073 public Label transformToOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo) { 074 float[][] predictions = getBatchPredictions(tensor,outputIDInfo); 075 if (predictions.length != 1) { 076 throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length); 077 } 078 return generateLabel(predictions[0],outputIDInfo); 079 } 080 081 private Label generateLabel(float[] predictions, ImmutableOutputInfo<Label> outputIDInfo) { 082 int maxIdx = 0; 083 float max = Float.NEGATIVE_INFINITY; 084 for (int i = 0; i < predictions.length; i++) { 085 if (predictions[i] > max) { 086 maxIdx = i; 087 max = 
predictions[i]; 088 } 089 } 090 return new Label(outputIDInfo.getOutput(maxIdx).getLabel(),max); 091 } 092 093 /** 094 * Rationalises the output of an onnx model into a standard format suitable for 095 * downstream work in Tribuo. 096 * @param inputs The onnx model output. 097 * @param outputIDInfo The output id mapping. 098 * @return A 2d array of outputs, the first dimension is batch size, the second dimension is the output space. 099 */ 100 private float[][] getBatchPredictions(List<OnnxValue> inputs, ImmutableOutputInfo<Label> outputIDInfo) { 101 try { 102 if (inputs.size() == 1) { 103 // Single OnnxTensor [batchSize][numOutputDims] 104 if (inputs.get(0) instanceof OnnxTensor) { 105 OnnxTensor output = (OnnxTensor) inputs.get(0); 106 if (output.getInfo().type == OnnxJavaType.FLOAT) { 107 long[] shape = output.getInfo().getShape(); 108 if ((shape.length == 2) && (shape[1] == outputIDInfo.size())) { 109 return (float[][]) output.getValue(); 110 } else { 111 throw new IllegalArgumentException("Invalid shape for the probabilities tensor, expected shape [batchSize,numOutputs], found " + Arrays.toString(shape)); 112 } 113 } else { 114 throw new IllegalArgumentException("Expected the first element to be a float OnnxTensor, found " + inputs.get(0)); 115 } 116 } else { 117 throw new IllegalArgumentException("Expected the first element to be a float OnnxTensor, found " + inputs.get(0)); 118 } 119 } else if (inputs.size() == 2) { 120 // First element is OnnxTensor [batchSize] containing the int predicted label ids, second element is a OnnxSequence<ONNXMap<long,float>> 121 if (inputs.get(1) instanceof OnnxSequence) { 122 OnnxSequence seq = (OnnxSequence) inputs.get(1); 123 SequenceInfo info = seq.getInfo(); 124 if ((info.sequenceOfMaps) && (info.mapInfo.keyType == OnnxJavaType.INT64) && (info.mapInfo.valueType == OnnxJavaType.FLOAT)) { 125 List<?> output = seq.getValue(); 126 float[][] outputArray = new float[output.size()][outputIDInfo.size()]; 127 int i = 0; 128 for 
(Object o : output) { 129 @SuppressWarnings("unchecked") // guarded by the if on the mapInfo above. 130 Map<Long,Float> map = (Map<Long,Float>) o; 131 if (map.size() == outputIDInfo.size()) { 132 for (Map.Entry<Long,Float> e : map.entrySet()) { 133 Long key = e.getKey(); 134 if (key != (int)(long) key) { 135 throw new IllegalArgumentException("Key not representable as a Java int, this model is not supported. Expected value less than 2^32, received " + key); 136 } 137 outputArray[i][(int)(long)key] = e.getValue(); 138 } 139 } else { 140 throw new IllegalArgumentException("Expected " + outputIDInfo.size() + " entries in the " + i + "th element, found " + map.size()); 141 } 142 i++; 143 } 144 return outputArray; 145 } else { 146 throw new IllegalArgumentException("Expected a List<Map<Long,Float>>, received a " + info.toString()); 147 } 148 } else { 149 throw new IllegalArgumentException("Expected a List<Map<Long,Float>>, received a " + inputs.get(1).getInfo().toString()); 150 } 151 } else { 152 throw new IllegalArgumentException("Unexpected number of OnnxValues returned, expected 1 or 2, received " + inputs.size()); 153 } 154 } catch (OrtException e) { 155 throw new IllegalStateException("Failed to read a value out of the onnx result.",e); 156 } 157 } 158 159 @Override 160 public List<Prediction<Label>> transformToBatchPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo, int[] numValidFeatures, List<Example<Label>> examples) { 161 float[][] predictions = getBatchPredictions(tensor,outputIDInfo); 162 List<Prediction<Label>> output = new ArrayList<>(); 163 164 if ((predictions.length != examples.size()) || (predictions.length != numValidFeatures.length)) { 165 throw new IllegalArgumentException("Invalid number of predictions received from the ONNXExternalModel, expected " + numValidFeatures.length + ", received " + predictions.length); 166 } 167 168 for (int i = 0; i < predictions.length; i++) { 169 
output.add(generatePrediction(predictions[i],outputIDInfo,numValidFeatures[i],examples.get(i))); 170 } 171 172 return output; 173 } 174 175 @Override 176 public List<Label> transformToBatchOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo) { 177 float[][] predictions = getBatchPredictions(tensor,outputIDInfo); 178 List<Label> output = new ArrayList<>(); 179 180 for (int i = 0; i < predictions.length; i++) { 181 output.add(generateLabel(predictions[i],outputIDInfo)); 182 } 183 184 return output; 185 } 186 187 @Override 188 public boolean generatesProbabilities() { 189 return true; 190 } 191 192 @Override 193 public String toString() { 194 return "LabelTransformer()"; 195 } 196 197 @Override 198 public ConfiguredObjectProvenance getProvenance() { 199 return new ConfiguredObjectProvenanceImpl(this,"OutputTransformer"); 200 } 201}