001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.onnx; 018 019import ai.onnxruntime.OnnxTensor; 020import ai.onnxruntime.OrtEnvironment; 021import ai.onnxruntime.OrtException; 022import com.oracle.labs.mlrg.olcut.config.Config; 023import com.oracle.labs.mlrg.olcut.config.PropertyException; 024import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 025import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 026import org.tribuo.math.la.SparseVector; 027import org.tribuo.math.la.VectorTuple; 028 029import java.nio.ByteBuffer; 030import java.nio.ByteOrder; 031import java.nio.FloatBuffer; 032import java.util.List; 033 034/** 035 * Image transformer. Assumes the feature id numbers are linearised ids of the form 036 * [0,0,0] = 0, [1,0,0] = 1, ..., [i,0,0] = i, [0,1,0] = i+1, ..., [i,j,0] = i*j, ... 037 * [0,0,1] = (i*j)+1, ..., [i,j,k] = i*j*k. 038 * <p> 039 * ONNX expects images in the format [channels,height,width]. 040 */ 041public class ImageTransformer implements ExampleTransformer { 042 private static final long serialVersionUID = 1L; 043 044 @Config(mandatory=true,description="Image width.") 045 private int width; 046 047 @Config(mandatory=true,description="Image height.") 048 private int height; 049 050 @Config(mandatory=true,description="Number of channels.") 051 private int channels; 052 053 /** 054 * For olcut. 055 */ 056 private ImageTransformer() {} 057 058 public ImageTransformer(int channels, int height, int width) { 059 if (width < 1 || height < 1 || channels < 1) { 060 throw new PropertyException("","Inputs must be positive integers, found [c="+channels+",h="+height+",w="+width+"]"); 061 } 062 this.width = width; 063 this.height = height; 064 this.channels = channels; 065 } 066 067 /** 068 * Used by the OLCUT configuration system, and should not be called by external code. 069 */ 070 @Override 071 public void postConfig() { 072 if (width < 1 || height < 1 || channels < 1) { 073 throw new PropertyException("","Inputs must be positive integers, found [c="+channels+",h="+height+",w="+width+"]"); 074 } 075 } 076 077 /** 078 * Actually performs the transformation. Pads unseen values 079 * with zero. Writes to the buffer in multidimensional row-major form. 080 * @param buffer The buffer to write to. 081 * @param startPos The starting position of the buffer. 082 * @param vector The vector to transform. 083 */ 084 private void innerTransform(FloatBuffer buffer, int startPos, SparseVector vector) { 085 for (VectorTuple f : vector) { 086 int id = f.index; 087 buffer.put(id+startPos,(float)f.value); 088 } 089 } 090 091 @Override 092 public OnnxTensor transform(OrtEnvironment env, SparseVector vector) throws OrtException { 093 FloatBuffer buffer = ByteBuffer.allocateDirect(vector.size()*4).order(ByteOrder.nativeOrder()).asFloatBuffer(); 094 innerTransform(buffer,0,vector); 095 return OnnxTensor.createTensor(env,buffer,new long[]{1,channels,height,width}); 096 } 097 098 @Override 099 public OnnxTensor transform(OrtEnvironment env, List<SparseVector> vectors) throws OrtException { 100 if (vectors.isEmpty()) { 101 return OnnxTensor.createTensor(env,FloatBuffer.allocate(0),new long[]{0,channels,height,width}); 102 } else { 103 int initialSize = vectors.get(0).size(); 104 FloatBuffer buffer = ByteBuffer.allocateDirect(initialSize * vectors.size() * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); 105 int position = 0; 106 for (SparseVector v : vectors) { 107 innerTransform(buffer, position, v); 108 position += v.size(); 109 if (v.size() != initialSize) { 110 throw new IllegalArgumentException("Vectors are not all the same dimension, expected " + initialSize + ", found " + v.size()); 111 } 112 } 113 return OnnxTensor.createTensor(env, buffer, new long[]{vectors.size(), channels, height, width}); 114 } 115 } 116 117 @Override 118 public String toString() { 119 return "ImageTransformer(channels="+channels+",height="+height+",width="+width+")"; 120 } 121 122 @Override 123 public ConfiguredObjectProvenance getProvenance() { 124 return new ConfiguredObjectProvenanceImpl(this,"ExampleTransformer"); 125 } 126}