001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.onnx;
018
019import ai.onnxruntime.OnnxTensor;
020import ai.onnxruntime.OrtEnvironment;
021import ai.onnxruntime.OrtException;
022import com.oracle.labs.mlrg.olcut.config.Config;
023import com.oracle.labs.mlrg.olcut.config.PropertyException;
024import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
026import org.tribuo.math.la.SparseVector;
027import org.tribuo.math.la.VectorTuple;
028
029import java.nio.ByteBuffer;
030import java.nio.ByteOrder;
031import java.nio.FloatBuffer;
032import java.util.List;
033
034/**
035 * Image transformer. Assumes the feature id numbers are linearised ids of the form
036 * [0,0,0] = 0, [1,0,0] = 1, ..., [i,0,0] = i, [0,1,0] = i+1, ..., [i,j,0] = i*j, ...
037 * [0,0,1] = (i*j)+1, ..., [i,j,k] = i*j*k.
038 * <p>
039 * ONNX expects images in the format [channels,height,width].
040 */
041public class ImageTransformer implements ExampleTransformer {
042    private static final long serialVersionUID = 1L;
043
044    @Config(mandatory=true,description="Image width.")
045    private int width;
046
047    @Config(mandatory=true,description="Image height.")
048    private int height;
049
050    @Config(mandatory=true,description="Number of channels.")
051    private int channels;
052
053    /**
054     * For olcut.
055     */
056    private ImageTransformer() {}
057
058    public ImageTransformer(int channels, int height, int width) {
059        if (width < 1 || height < 1 || channels < 1) {
060            throw new PropertyException("","Inputs must be positive integers, found [c="+channels+",h="+height+",w="+width+"]");
061        }
062        this.width = width;
063        this.height = height;
064        this.channels = channels;
065    }
066
067    /**
068     * Used by the OLCUT configuration system, and should not be called by external code.
069     */
070    @Override
071    public void postConfig() {
072        if (width < 1 || height < 1 || channels < 1) {
073            throw new PropertyException("","Inputs must be positive integers, found [c="+channels+",h="+height+",w="+width+"]");
074        }
075    }
076
077    /**
078     * Actually performs the transformation. Pads unseen values
079     * with zero. Writes to the buffer in multidimensional row-major form.
080     * @param buffer The buffer to write to.
081     * @param startPos The starting position of the buffer.
082     * @param vector The vector to transform.
083     */
084    private void innerTransform(FloatBuffer buffer, int startPos, SparseVector vector) {
085        for (VectorTuple f : vector) {
086            int id = f.index;
087            buffer.put(id+startPos,(float)f.value);
088        }
089    }
090
091    @Override
092    public OnnxTensor transform(OrtEnvironment env, SparseVector vector) throws OrtException {
093        FloatBuffer buffer = ByteBuffer.allocateDirect(vector.size()*4).order(ByteOrder.nativeOrder()).asFloatBuffer();
094        innerTransform(buffer,0,vector);
095        return OnnxTensor.createTensor(env,buffer,new long[]{1,channels,height,width});
096    }
097
098    @Override
099    public OnnxTensor transform(OrtEnvironment env, List<SparseVector> vectors) throws OrtException {
100        if (vectors.isEmpty()) {
101            return OnnxTensor.createTensor(env,FloatBuffer.allocate(0),new long[]{0,channels,height,width});
102        } else {
103            int initialSize = vectors.get(0).size();
104            FloatBuffer buffer = ByteBuffer.allocateDirect(initialSize * vectors.size() * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
105            int position = 0;
106            for (SparseVector v : vectors) {
107                innerTransform(buffer, position, v);
108                position += v.size();
109                if (v.size() != initialSize) {
110                    throw new IllegalArgumentException("Vectors are not all the same dimension, expected " + initialSize + ", found " + v.size());
111                }
112            }
113            return OnnxTensor.createTensor(env, buffer, new long[]{vectors.size(), channels, height, width});
114        }
115    }
116
117    @Override
118    public String toString() {
119        return "ImageTransformer(channels="+channels+",height="+height+",width="+width+")";
120    }
121
122    @Override
123    public ConfiguredObjectProvenance getProvenance() {
124        return new ConfiguredObjectProvenanceImpl(this,"ExampleTransformer");
125    }
126}