001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.onnx; 018 019import ai.onnxruntime.OnnxJavaType; 020import ai.onnxruntime.OnnxTensor; 021import ai.onnxruntime.OnnxValue; 022import ai.onnxruntime.OrtException; 023import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 024import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 025import com.oracle.labs.mlrg.olcut.util.Pair; 026import org.tribuo.Example; 027import org.tribuo.ImmutableOutputInfo; 028import org.tribuo.Prediction; 029import org.tribuo.regression.Regressor; 030 031import java.util.ArrayList; 032import java.util.Arrays; 033import java.util.List; 034 035/** 036 * Can convert an {@link OnnxValue} into a {@link Prediction} or {@link Regressor}. 037 */ 038public class RegressorTransformer implements OutputTransformer<Regressor> { 039 private static final long serialVersionUID = 1L; 040 041 @Override 042 public Prediction<Regressor> transformToPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int numValidFeatures, Example<Regressor> example) { 043 Regressor r = transformToOutput(tensor,outputIDInfo); 044 return new Prediction<>(r,numValidFeatures,example); 045 } 046 047 @Override 048 public Regressor transformToOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo) { 049 float[][] predictions = getBatchPredictions(tensor); 050 if (predictions.length != 1) { 051 throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length); 052 } else if (predictions[0].length != outputIDInfo.size()) { 053 throw new IllegalArgumentException("Supplied tensor has an incorrect number of dimensions, predictions[0].length = " + predictions[0].length + ", expected " + outputIDInfo.size()); 054 } 055 String[] names = new String[outputIDInfo.size()]; 056 double[] values = new double[outputIDInfo.size()]; 057 for (Pair<Integer,Regressor> p : outputIDInfo) { 058 int id = p.getA(); 059 names[id] = p.getB().getNames()[0]; 060 values[id] = predictions[0][id]; 061 } 062 return new Regressor(names,values); 063 } 064 065 private float[][] getBatchPredictions(List<OnnxValue> valueList) { 066 if (valueList.size() != 1) { 067 throw new IllegalArgumentException("Supplied output has incorrect number of elements, expected 1, found " + valueList.size()); 068 } 069 070 OnnxValue value = valueList.get(0); 071 if (value instanceof OnnxTensor) { 072 OnnxTensor tensor = (OnnxTensor) value; 073 long[] shape = tensor.getInfo().getShape(); 074 if (shape.length != 2) { 075 throw new IllegalArgumentException("Expected shape [batchSize][numDimensions], found " + Arrays.toString(shape)); 076 } else { 077 try { 078 if (tensor.getInfo().type == OnnxJavaType.FLOAT) { 079 // Will return a float array 080 return (float[][]) tensor.getValue(); 081 } else { 082 throw new IllegalArgumentException("Supplied output was an invalid tensor type, expected float, found " + tensor.getInfo().type); 083 } 084 } catch (OrtException e) { 085 throw new IllegalStateException("Failed to read tensor value",e); 086 } 087 } 088 } else { 089 throw new IllegalArgumentException("Supplied output was not an OnnxTensor, found " + value.getClass().toString()); 090 } 091 } 092 093 @Override 094 public List<Prediction<Regressor>> transformToBatchPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int[] numValidFeatures, List<Example<Regressor>> examples) { 095 List<Regressor> regressors = transformToBatchOutput(tensor,outputIDInfo); 096 List<Prediction<Regressor>> output = new ArrayList<>(); 097 098 if ((regressors.size() != examples.size()) || (regressors.size() != numValidFeatures.length)) { 099 throw new IllegalArgumentException("Invalid number of predictions received from the ONNXExternalModel, expected " + numValidFeatures.length + ", received " + regressors.size()); 100 } 101 102 for (int i = 0; i < regressors.size(); i++) { 103 output.add(new Prediction<>(regressors.get(i),numValidFeatures[i],examples.get(i))); 104 } 105 106 return output; 107 } 108 109 @Override 110 public List<Regressor> transformToBatchOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo) { 111 float[][] predictions = getBatchPredictions(tensor); 112 List<Regressor> output = new ArrayList<>(); 113 114 String[] names = new String[outputIDInfo.size()]; 115 for (Pair<Integer,Regressor> p : outputIDInfo) { 116 int id = p.getA(); 117 names[id] = p.getB().getNames()[0]; 118 } 119 for (int i = 0; i < predictions.length; i++) { 120 double[] values = new double[names.length]; 121 for (int j = 0; j < names.length; j++) { 122 values[j] = predictions[i][j]; 123 } 124 output.add(new Regressor(names,values)); 125 } 126 127 return output; 128 } 129 130 @Override 131 public boolean generatesProbabilities() { 132 return false; 133 } 134 135 @Override 136 public String toString() { 137 return "RegressorTransformer()"; 138 } 139 140 @Override 141 public ConfiguredObjectProvenance getProvenance() { 142 return new ConfiguredObjectProvenanceImpl(this,"OutputTransformer"); 143 } 144}