/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tensorflow.Tensor;

import java.util.List;

/**
 * Image transformer. Assumes the feature id numbers are linearised ids where the
 * pixel at position [w,h,c] has the id
 * {@code w + (width * h) + (width * height * c)}, i.e. the width index varies
 * fastest, then the height index, then the channel index.
 * <p>
 * For example, with width = 2 and height = 2: [0,0,0] = 0, [1,0,0] = 1,
 * [0,1,0] = 2, [1,1,0] = 3, [0,0,1] = 4, ...
 *
 * @param <T> The output type of the examples being transformed.
 */
public class ImageTransformer<T extends Output<T>> implements ExampleTransformer<T> {
    private static final long serialVersionUID = 1L;

    @Config(mandatory=true,description="Image width.")
    private int width;

    @Config(mandatory=true,description="Image height.")
    private int height;

    @Config(mandatory=true,description="Number of channels.")
    private int channels;

    /**
     * For olcut.
     */
    private ImageTransformer() {}

    /**
     * Builds an image transformer for images of the supplied dimensions.
     * @param width The image width, must be positive.
     * @param height The image height, must be positive.
     * @param channels The number of channels, must be positive.
     * @throws IllegalArgumentException If any of the dimensions is less than one.
     */
    public ImageTransformer(int width, int height, int channels) {
        if (width < 1 || height < 1 || channels < 1) {
            throw new IllegalArgumentException("Inputs must be positive integers, found ["+width+","+height+","+channels+"]");
        }
        this.width = width;
        this.height = height;
        this.channels = channels;
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     * <p>
     * Applies the same dimension validation as the public constructor.
     */
    @Override
    public void postConfig() {
        if (width < 1 || height < 1 || channels < 1) {
            throw new PropertyException("","Inputs must be positive integers, found ["+width+","+height+","+channels+"]");
        }
    }

    /**
     * Transform implicitly pads unseen values with zero.
     * @param example The example to transform.
     * @param featureIDMap The feature id mapping to use.
     * @return A 4d tensor, (1, width, height, channels) for this example.
     */
    @Override
    public Tensor<?> transform(Example<T> example, ImmutableFeatureMap featureIDMap) {
        // Wrap the single image in a batch dimension of size one.
        float[][][][] image = new float[1][][][];
        image[0] = innerTransform(example,featureIDMap);
        return Tensor.create(image);
    }

    /**
     * Actually performs the transformation. Implicitly pads unseen values
     * with zero.
     * <p>
     * Features whose names are not present in the feature map are ignored.
     * @param example The example to transform.
     * @param featureIDMap The feature id mapping to use.
     * @return A 3d array, (width,height,channels) representing the example.
     */
    float[][][] innerTransform(Example<T> example, ImmutableFeatureMap featureIDMap) {
        float[][][] image = new float[width][height][channels];

        for (Feature f : example) {
            int id = featureIDMap.getID(f.getName());
            // The feature map reports unknown feature names with a negative id.
            // Such a feature has no pixel, and decoding the id would produce
            // negative array indices, so it is skipped.
            if (id > -1) {
                writePixel(image, id, (float) f.getValue());
            }
        }

        return image;
    }

    /**
     * Actually performs the transformation. Implicitly pads unseen values
     * with zero.
     * @param vector The vector to transform.
     * @return A 3d array, (width,height,channels) representing the vector.
     */
    float[][][] innerTransform(SparseVector vector) {
        float[][][] image = new float[width][height][channels];

        for (VectorTuple f : vector) {
            writePixel(image, f.index, (float) f.value);
        }

        return image;
    }

    /**
     * Decodes a linearised feature id into (width, height, channel) coordinates
     * and writes the value into that position of the image.
     * <p>
     * An id of {@code width * height * channels} or greater decodes to a channel
     * index outside the image and results in an {@link ArrayIndexOutOfBoundsException}.
     * @param image The (width,height,channels) array to write into.
     * @param id The non-negative linearised feature id.
     * @param value The pixel value.
     */
    private void writePixel(float[][][] image, int id, float value) {
        int curWidth = id % width;
        int curHeight = (id / width) % height;
        int curChannel = id / (width * height);
        image[curWidth][curHeight][curChannel] = value;
    }

    /**
     * Transform implicitly pads unseen values with zero.
     * <p>
     * Converts a batch of examples into a Tensor.
     * @param examples The examples to transform.
     * @param featureIDMap The feature id mapping to use.
     * @return A 4d tensor, (batch-id, width, height, channels) for this batch.
     */
    @Override
    public Tensor<?> transform(List<Example<T>> examples, ImmutableFeatureMap featureIDMap) {
        float[][][][] image = new float[examples.size()][][][];

        int i = 0;
        for (Example<T> example : examples) {
            image[i] = innerTransform(example,featureIDMap);
            i++;
        }

        return Tensor.create(image);
    }

    /**
     * Transform implicitly pads unseen values with zero.
     * <p>
     * The vector's indices are used directly as the linearised feature ids.
     * @param vector The vector to transform.
     * @return A 4d tensor, (1, width, height, channels) for this vector.
     */
    @Override
    public Tensor<?> transform(SparseVector vector) {
        // Wrap the single image in a batch dimension of size one.
        float[][][][] image = new float[1][][][];
        image[0] = innerTransform(vector);
        return Tensor.create(image);
    }

    /**
     * Transform implicitly pads unseen values with zero.
     * <p>
     * Converts a batch of vectors into a Tensor, using each vector's indices
     * directly as the linearised feature ids.
     * @param vectors The vectors to transform.
     * @return A 4d tensor, (batch-id, width, height, channels) for this batch.
     */
    @Override
    public Tensor<?> transform(List<SparseVector> vectors) {
        float[][][][] image = new float[vectors.size()][][][];

        int i = 0;
        for (SparseVector vector : vectors) {
            image[i] = innerTransform(vector);
            i++;
        }

        return Tensor.create(image);
    }

    @Override
    public String toString() {
        return "ImageTransformer(width="+width+",height="+height+",channels="+channels+")";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"ExampleTransformer");
    }
}