001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.tensorflow;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
023import org.tribuo.Example;
024import org.tribuo.Feature;
025import org.tribuo.ImmutableFeatureMap;
026import org.tribuo.Output;
027import org.tribuo.math.la.SparseVector;
028import org.tribuo.math.la.VectorTuple;
029import org.tensorflow.Tensor;
030
031import java.util.List;
032
033/**
034 * Image transformer. Assumes the feature id numbers are linearised ids of the form
035 * [0,0,0] = 0, [1,0,0] = 1, ..., [i,0,0] = i, [0,1,0] = i+1, ..., [i,j,0] = i*j, ...
036 * [0,0,1] = (i*j)+1, ..., [i,j,k] = i*j*k.
037 */
038public class ImageTransformer<T extends Output<T>> implements ExampleTransformer<T> {
039    private static final long serialVersionUID = 1L;
040
041    @Config(mandatory=true,description="Image width.")
042    private int width;
043
044    @Config(mandatory=true,description="Image height.")
045    private int height;
046
047    @Config(mandatory=true,description="Number of channels.")
048    private int channels;
049
050    /**
051     * For olcut.
052     */
053    private ImageTransformer() {}
054
055    public ImageTransformer(int width, int height, int channels) {
056        if (width < 1 || height < 1 || channels < 1) {
057            throw new IllegalArgumentException("Inputs must be positive integers, found ["+width+","+height+","+channels+"]");
058        }
059        this.width = width;
060        this.height = height;
061        this.channels = channels;
062    }
063
064    /**
065     * Used by the OLCUT configuration system, and should not be called by external code.
066     */
067    @Override
068    public void postConfig() {
069        if (width < 1 || height < 1 || channels < 1) {
070            throw new PropertyException("","Inputs must be positive integers, found ["+width+","+height+","+channels+"]");
071        }
072    }
073
074    /**
075     * Transform implicitly pads unseen values with zero.
076     * @param example The example to transform.
077     * @param featureIDMap The feature id mapping to use.
078     * @return A 3d tensor, (width, height, channels) for this example.
079     */
080    @Override
081    public Tensor<?> transform(Example<T> example, ImmutableFeatureMap featureIDMap) {
082        float[][][][] image = new float[1][][][];
083        image[0] = innerTransform(example,featureIDMap);
084        return Tensor.create(image);
085    }
086
087    /**
088     * Actually performs the transformation. Implicitly pads unseen values
089     * with zero.
090     * @param example The example to transform.
091     * @param featureIDMap The feature id mapping to use.
092     * @return A 3d array, (width,height,channels) representing the example.
093     */
094    float[][][] innerTransform(Example<T> example, ImmutableFeatureMap featureIDMap) {
095        float[][][] image = new float[width][height][channels];
096
097        for (Feature f : example) {
098            int id = featureIDMap.getID(f.getName());
099            int curWidth = id % width;
100            int curHeight = (id / width) % height;
101            int curChannel = id / (width * height);
102            image[curWidth][curHeight][curChannel] = (float) f.getValue();
103        }
104
105        return image;
106    }
107
108    /**
109     * Actually performs the transformation. Implicitly pads unseen values
110     * with zero.
111     * @param vector The vector to transform.
112     * @return A 3d array, (width,height,channels) representing the vector.
113     */
114    float[][][] innerTransform(SparseVector vector) {
115        float[][][] image = new float[width][height][channels];
116
117        for (VectorTuple f : vector) {
118            int id = f.index;
119            int curWidth = id % width;
120            int curHeight = (id / width) % height;
121            int curChannel = id / (width * height);
122            image[curWidth][curHeight][curChannel] = (float) f.value;
123        }
124
125        return image;
126    }
127
128    /**
129     * Transform implicitly pads unseen values with zero.
130     * <p>
131     * Converts a batch of examples into a Tensor.
132     * @param examples The examples to transform.
133     * @param featureIDMap The feature id mapping to use.
134     * @return A 4d tensor, (batch-id, width, height, channels) for this example.
135     */
136    @Override
137    public Tensor<?> transform(List<Example<T>> examples, ImmutableFeatureMap featureIDMap) {
138        float[][][][] image = new float[examples.size()][][][];
139
140        int i = 0;
141        for (Example<T> example : examples) {
142            image[i] = innerTransform(example,featureIDMap);
143            i++;
144        }
145
146        return Tensor.create(image);
147    }
148
149    @Override
150    public Tensor<?> transform(SparseVector vector) {
151        float[][][][] image = new float[1][][][];
152        image[0] = innerTransform(vector);
153        return Tensor.create(image);
154    }
155
156    @Override
157    public Tensor<?> transform(List<SparseVector> vectors) {
158        float[][][][] image = new float[vectors.size()][][][];
159
160        int i = 0;
161        for (SparseVector vector : vectors) {
162            image[i] = innerTransform(vector);
163            i++;
164        }
165
166        return Tensor.create(image);
167    }
168
169    @Override
170    public String toString() {
171        return "ImageTransformer(width="+width+",height="+height+",channels="+channels+")";
172    }
173
174    @Override
175    public ConfiguredObjectProvenance getProvenance() {
176        return new ConfiguredObjectProvenanceImpl(this,"ExampleTransformer");
177    }
178}