001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.tensorflow; 018 019import org.tensorflow.Graph; 020import org.tensorflow.Operation; 021import org.tensorflow.OperationBuilder; 022import org.tensorflow.Session; 023import org.tensorflow.Tensor; 024 025import java.util.ArrayList; 026import java.util.HashMap; 027import java.util.Iterator; 028import java.util.List; 029import java.util.Map; 030import java.util.logging.Level; 031import java.util.logging.Logger; 032 033/** 034 * Helper functions for working with Tensorflow. 035 */ 036public class TensorflowUtil { 037 private static final Logger logger = Logger.getLogger(TensorflowUtil.class.getName()); 038 039 public static final String VARIABLE_V2 = "VariableV2"; 040 public static final String ASSIGN_OP = "Assign"; 041 public static final String ASSIGN_PLACEHOLDER = "Assign_from_Placeholder"; 042 public static final String PLACEHOLDER = "Placeholder"; 043 public static final String DTYPE = "dtype"; 044 045 /** 046 * Creates a new primitive boolean array of up to 8 dimensions, using the supplied shape. 047 * <p> 048 * Does not check the shape to see if all it's elements are positive. 049 * 050 * @param shape The shape of array to create. 051 * @return A boolean array. 052 */ 053 public static Object newBooleanArray(long[] shape) { 054 switch (shape.length) { 055 case 1: 056 return new boolean[(int) shape[0]]; 057 case 2: 058 return new boolean[(int) shape[0]][(int) shape[1]]; 059 case 3: 060 return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]]; 061 case 4: 062 return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]]; 063 case 5: 064 return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]]; 065 case 6: 066 return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]]; 067 case 7: 068 return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]]; 069 case 8: 070 return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]]; 071 default: 072 throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported."); 073 } 074 } 075 076 /** 077 * Creates a new primitive byte array of up to 8 dimensions, using the supplied shape. 078 * <p> 079 * Does not check the shape to see if all it's elements are positive. 080 * 081 * @param shape The shape of array to create. 082 * @return A byte array. 083 */ 084 public static Object newByteArray(long[] shape) { 085 switch (shape.length) { 086 case 1: 087 return new byte[(int) shape[0]]; 088 case 2: 089 return new byte[(int) shape[0]][(int) shape[1]]; 090 case 3: 091 return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]]; 092 case 4: 093 return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]]; 094 case 5: 095 return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]]; 096 case 6: 097 return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]]; 098 case 7: 099 return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]]; 100 case 8: 101 return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]]; 102 default: 103 throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported."); 104 } 105 } 106 107 /** 108 * Creates a new primitive int array of up to 8 dimensions, using the supplied shape. 109 * <p> 110 * Does not check the shape to see if all it's elements are positive. 111 * 112 * @param shape The shape of array to create. 113 * @return A int array. 114 */ 115 public static Object newIntArray(long[] shape) { 116 switch (shape.length) { 117 case 1: 118 return new int[(int) shape[0]]; 119 case 2: 120 return new int[(int) shape[0]][(int) shape[1]]; 121 case 3: 122 return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]]; 123 case 4: 124 return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]]; 125 case 5: 126 return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]]; 127 case 6: 128 return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]]; 129 case 7: 130 return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]]; 131 case 8: 132 return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]]; 133 default: 134 throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported."); 135 } 136 } 137 138 /** 139 * Creates a new primitive long array of up to 8 dimensions, using the supplied shape. 140 * <p> 141 * Does not check the shape to see if all it's elements are positive. 142 * 143 * @param shape The shape of array to create. 144 * @return A long array. 145 */ 146 public static Object newLongArray(long[] shape) { 147 switch (shape.length) { 148 case 1: 149 return new long[(int) shape[0]]; 150 case 2: 151 return new long[(int) shape[0]][(int) shape[1]]; 152 case 3: 153 return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]]; 154 case 4: 155 return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]]; 156 case 5: 157 return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]]; 158 case 6: 159 return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]]; 160 case 7: 161 return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]]; 162 case 8: 163 return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]]; 164 default: 165 throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported."); 166 } 167 } 168 169 /** 170 * Creates a new primitive float array of up to 8 dimensions, using the supplied shape. 171 * <p> 172 * Does not check the shape to see if all it's elements are positive. 173 * 174 * @param shape The shape of array to create. 175 * @return A float array. 176 */ 177 public static Object newFloatArray(long[] shape) { 178 switch (shape.length) { 179 case 1: 180 return new float[(int) shape[0]]; 181 case 2: 182 return new float[(int) shape[0]][(int) shape[1]]; 183 case 3: 184 return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]]; 185 case 4: 186 return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]]; 187 case 5: 188 return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]]; 189 case 6: 190 return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]]; 191 case 7: 192 return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]]; 193 case 8: 194 return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]]; 195 default: 196 throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported."); 197 } 198 } 199 200 /** 201 * Creates a new primitive double array of up to 8 dimensions, using the supplied shape. 202 * <p> 203 * Does not check the shape to see if all it's elements are positive. 204 * 205 * @param shape The shape of array to create. 206 * @return A double array. 207 */ 208 public static Object newDoubleArray(long[] shape) { 209 switch (shape.length) { 210 case 1: 211 return new double[(int) shape[0]]; 212 case 2: 213 return new double[(int) shape[0]][(int) shape[1]]; 214 case 3: 215 return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]]; 216 case 4: 217 return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]]; 218 case 5: 219 return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]]; 220 case 6: 221 return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]]; 222 case 7: 223 return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]]; 224 case 8: 225 return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]]; 226 default: 227 throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported."); 228 } 229 } 230 231 /** 232 * Closes a list of {@link Tensor}s. 233 * 234 * @param tensorList The list of tensors to close. 235 */ 236 public static void closeTensorList(List<Tensor<?>> tensorList) { 237 for (Tensor<?> t : tensorList) { 238 t.close(); 239 } 240 } 241 242 /** 243 * Extracts the appropriate type of primitive array from a {@link Tensor}. 244 * <p> 245 * Returns an object as the user doesn't know what type is in the {@link Tensor}. 246 * 247 * @param tensor The tensor to read. 248 * @return A primitive array. 249 */ 250 public static Object convertTensorToArray(Tensor<?> tensor) { 251 long[] shape = tensor.shape(); 252 253 Object array; 254 switch (tensor.dataType()) { 255 case FLOAT: 256 array = newFloatArray(shape); 257 break; 258 case DOUBLE: 259 array = newDoubleArray(shape); 260 break; 261 case INT32: 262 array = newIntArray(shape); 263 break; 264 case UINT8: 265 case STRING: 266 array = newByteArray(shape); 267 break; 268 case INT64: 269 array = newLongArray(shape); 270 break; 271 case BOOL: 272 array = newBooleanArray(shape); 273 break; 274 default: 275 throw new IllegalArgumentException("Tribuo can't serialise Tensors with type " + tensor.dataType()); 276 } 277 278 tensor.copyTo(array); 279 280 return array; 281 } 282 283 /** 284 * Converts a {@link Tensor} into a scalar object, boxing the primitive types. 285 * <p> 286 * Does not close the Tensor. 287 * 288 * @param tensor The tensor to convert. 289 * @return A boxed scalar. 290 */ 291 public static Object convertTensorToScalar(Tensor<?> tensor) { 292 Object scalar; 293 switch (tensor.dataType()) { 294 case FLOAT: 295 scalar = tensor.floatValue(); 296 break; 297 case DOUBLE: 298 scalar = tensor.doubleValue(); 299 break; 300 case INT32: 301 scalar = tensor.intValue(); 302 break; 303 case UINT8: 304 scalar = (byte) (tensor.intValue() & 0xFF); 305 break; 306 case STRING: 307 scalar = tensor.bytesValue(); 308 break; 309 case INT64: 310 scalar = tensor.longValue(); 311 break; 312 case BOOL: 313 scalar = tensor.booleanValue(); 314 break; 315 default: 316 throw new IllegalArgumentException("Tribuo can't serialise Tensors with type " + tensor.dataType()); 317 } 318 return scalar; 319 } 320 321 /** 322 * Annotates a graph with an extra placeholder and assign operation for each 323 * VariableV2. This allows the graph to be deserialised using {@link TensorflowUtil#deserialise(Session, Map)}. 324 * <p> 325 * This operation can either be done each time the Graph is loaded before deserialise is called, 326 * or once, and the updated graphDef persisted with the Map produced by serialise. 327 * <p> 328 * Requires a session to correctly get the output type of a VariableV2. This isn't strictly necessary, 329 * but there aren't typed ways to get outputs in the TF version we use. 330 * 331 * @param graph The graph to annotate. 332 * @param session The session to use. 333 */ 334 public static void annotateGraph(Graph graph, Session session) { 335 List<String> variableNames = new ArrayList<>(); 336 Map<String, Operation> opMap = new HashMap<>(); 337 Iterator<Operation> opItr = graph.operations(); 338 while (opItr.hasNext()) { 339 Operation op = opItr.next(); 340 if (op.type().equals(VARIABLE_V2)) { 341 variableNames.add(op.name()); 342 opMap.put(op.name(), op); 343 } 344 } 345 346 Session.Runner runner = session.runner(); 347 for (String s : variableNames) { 348 runner.fetch(s); 349 } 350 351 List<Tensor<?>> output = runner.run(); 352 353 if (output.size() != variableNames.size()) { 354 TensorflowUtil.closeTensorList(output); 355 throw new IllegalStateException("Failed to annotate all requested variables. Requested " + variableNames.size() + ", found " + output.size()); 356 } 357 358 for (int i = 0; i < output.size(); i++) { 359 OperationBuilder builder = graph.opBuilder(PLACEHOLDER, generatePlaceholderName(variableNames.get(i))); 360 builder.setAttr(DTYPE, output.get(i).dataType()); 361 Operation o = builder.build(); 362 builder = graph.opBuilder(ASSIGN_OP, variableNames.get(i) + "/" + ASSIGN_PLACEHOLDER); 363 builder.addInput(opMap.get(variableNames.get(i)).output(0)); 364 builder.addInput(o.output(0)); 365 builder.build(); 366 } 367 368 TensorflowUtil.closeTensorList(output); 369 } 370 371 /** 372 * Creates a name for a placeholder based on the supplied variable name. 373 * 374 * @param variableName The variable name to use as a base. 375 * @return A name for the placeholder. 376 */ 377 public static String generatePlaceholderName(String variableName) { 378 return variableName + "-" + PLACEHOLDER; 379 } 380 381 /** 382 * Extracts a Map containing the name of each Tensorflow VariableV2 and the 383 * associated parameter array. This map can then be serialised to disk. 384 * 385 * @param graph The graph to read operations from. 386 * @param session The session to read from. 387 * @return A map containing all variable names and parameter arrays. 388 */ 389 public static Map<String, Object> serialise(Graph graph, Session session) { 390 List<String> variableNames = new ArrayList<>(); 391 Iterator<Operation> opItr = graph.operations(); 392 while (opItr.hasNext()) { 393 Operation op = opItr.next(); 394 if (op.type().equals(VARIABLE_V2)) { 395 variableNames.add(op.name()); 396 } 397 } 398 399 Session.Runner runner = session.runner(); 400 for (String s : variableNames) { 401 runner.fetch(s); 402 } 403 List<Tensor<?>> output = runner.run(); 404 405 if (output.size() != variableNames.size()) { 406 closeTensorList(output); 407 throw new IllegalStateException("Failed to serialise all requested variables. Requested " + variableNames.size() + ", found " + output.size()); 408 } 409 410 Map<String, Object> tensorMap = new HashMap<>(); 411 for (int i = 0; i < variableNames.size(); i++) { 412 String name = variableNames.get(i); 413 Tensor<?> tensor = output.get(i); 414 Object value; 415 if (tensor.numDimensions() == 0) { 416 value = convertTensorToScalar(tensor); 417 } else { 418 value = convertTensorToArray(tensor); 419 } 420 tensorMap.put(name, value); 421 } 422 423 closeTensorList(output); 424 425 return tensorMap; 426 } 427 428 /** 429 * Writes a map containing the name of each Tensorflow VariableV2 and the associated 430 * parameter array into the supplied session. 431 * 432 * @param session The session to write to. 433 * @param tensorMap The parameter map to write. 434 */ 435 public static void deserialise(Session session, Map<String, Object> tensorMap) { 436 Session.Runner runner = session.runner(); 437 List<Tensor<?>> tensors = new ArrayList<>(); 438 for (Map.Entry<String, Object> e : tensorMap.entrySet()) { 439 logger.log(Level.FINEST, "Loading " + e.getKey() + " of type " + e.getValue().getClass().getName()); 440 Tensor<?> tensor = Tensor.create(e.getValue()); 441 runner.feed(generatePlaceholderName(e.getKey()), tensor); 442 runner.addTarget(e.getKey() + "/" + ASSIGN_PLACEHOLDER); 443 tensors.add(tensor); 444 } 445 runner.run(); 446 closeTensorList(tensors); 447 } 448}