/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package org.tribuo.interop.tensorflow;
018
019import org.tensorflow.Graph;
020import org.tensorflow.Operation;
021import org.tensorflow.OperationBuilder;
022import org.tensorflow.Session;
023import org.tensorflow.Tensor;
024
025import java.util.ArrayList;
026import java.util.HashMap;
027import java.util.Iterator;
028import java.util.List;
029import java.util.Map;
030import java.util.logging.Level;
031import java.util.logging.Logger;
032
033/**
034 * Helper functions for working with Tensorflow.
035 */
036public class TensorflowUtil {
037    private static final Logger logger = Logger.getLogger(TensorflowUtil.class.getName());
038
039    public static final String VARIABLE_V2 = "VariableV2";
040    public static final String ASSIGN_OP = "Assign";
041    public static final String ASSIGN_PLACEHOLDER = "Assign_from_Placeholder";
042    public static final String PLACEHOLDER = "Placeholder";
043    public static final String DTYPE = "dtype";
044
045    /**
046     * Creates a new primitive boolean array of up to 8 dimensions, using the supplied shape.
047     * <p>
048     * Does not check the shape to see if all it's elements are positive.
049     *
050     * @param shape The shape of array to create.
051     * @return A boolean array.
052     */
053    public static Object newBooleanArray(long[] shape) {
054        switch (shape.length) {
055            case 1:
056                return new boolean[(int) shape[0]];
057            case 2:
058                return new boolean[(int) shape[0]][(int) shape[1]];
059            case 3:
060                return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]];
061            case 4:
062                return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]];
063            case 5:
064                return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]];
065            case 6:
066                return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]];
067            case 7:
068                return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]];
069            case 8:
070                return new boolean[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]];
071            default:
072                throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
073        }
074    }
075
076    /**
077     * Creates a new primitive byte array of up to 8 dimensions, using the supplied shape.
078     * <p>
079     * Does not check the shape to see if all it's elements are positive.
080     *
081     * @param shape The shape of array to create.
082     * @return A byte array.
083     */
084    public static Object newByteArray(long[] shape) {
085        switch (shape.length) {
086            case 1:
087                return new byte[(int) shape[0]];
088            case 2:
089                return new byte[(int) shape[0]][(int) shape[1]];
090            case 3:
091                return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]];
092            case 4:
093                return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]];
094            case 5:
095                return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]];
096            case 6:
097                return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]];
098            case 7:
099                return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]];
100            case 8:
101                return new byte[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]];
102            default:
103                throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
104        }
105    }
106
107    /**
108     * Creates a new primitive int array of up to 8 dimensions, using the supplied shape.
109     * <p>
110     * Does not check the shape to see if all it's elements are positive.
111     *
112     * @param shape The shape of array to create.
113     * @return A int array.
114     */
115    public static Object newIntArray(long[] shape) {
116        switch (shape.length) {
117            case 1:
118                return new int[(int) shape[0]];
119            case 2:
120                return new int[(int) shape[0]][(int) shape[1]];
121            case 3:
122                return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]];
123            case 4:
124                return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]];
125            case 5:
126                return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]];
127            case 6:
128                return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]];
129            case 7:
130                return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]];
131            case 8:
132                return new int[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]];
133            default:
134                throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
135        }
136    }
137
138    /**
139     * Creates a new primitive long array of up to 8 dimensions, using the supplied shape.
140     * <p>
141     * Does not check the shape to see if all it's elements are positive.
142     *
143     * @param shape The shape of array to create.
144     * @return A long array.
145     */
146    public static Object newLongArray(long[] shape) {
147        switch (shape.length) {
148            case 1:
149                return new long[(int) shape[0]];
150            case 2:
151                return new long[(int) shape[0]][(int) shape[1]];
152            case 3:
153                return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]];
154            case 4:
155                return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]];
156            case 5:
157                return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]];
158            case 6:
159                return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]];
160            case 7:
161                return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]];
162            case 8:
163                return new long[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]];
164            default:
165                throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
166        }
167    }
168
169    /**
170     * Creates a new primitive float array of up to 8 dimensions, using the supplied shape.
171     * <p>
172     * Does not check the shape to see if all it's elements are positive.
173     *
174     * @param shape The shape of array to create.
175     * @return A float array.
176     */
177    public static Object newFloatArray(long[] shape) {
178        switch (shape.length) {
179            case 1:
180                return new float[(int) shape[0]];
181            case 2:
182                return new float[(int) shape[0]][(int) shape[1]];
183            case 3:
184                return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]];
185            case 4:
186                return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]];
187            case 5:
188                return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]];
189            case 6:
190                return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]];
191            case 7:
192                return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]];
193            case 8:
194                return new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]];
195            default:
196                throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
197        }
198    }
199
200    /**
201     * Creates a new primitive double array of up to 8 dimensions, using the supplied shape.
202     * <p>
203     * Does not check the shape to see if all it's elements are positive.
204     *
205     * @param shape The shape of array to create.
206     * @return A double array.
207     */
208    public static Object newDoubleArray(long[] shape) {
209        switch (shape.length) {
210            case 1:
211                return new double[(int) shape[0]];
212            case 2:
213                return new double[(int) shape[0]][(int) shape[1]];
214            case 3:
215                return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]];
216            case 4:
217                return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]];
218            case 5:
219                return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]];
220            case 6:
221                return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]];
222            case 7:
223                return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]];
224            case 8:
225                return new double[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]][(int) shape[4]][(int) shape[5]][(int) shape[6]][(int) shape[7]];
226            default:
227                throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
228        }
229    }
230
231    /**
232     * Closes a list of {@link Tensor}s.
233     *
234     * @param tensorList The list of tensors to close.
235     */
236    public static void closeTensorList(List<Tensor<?>> tensorList) {
237        for (Tensor<?> t : tensorList) {
238            t.close();
239        }
240    }
241
242    /**
243     * Extracts the appropriate type of primitive array from a {@link Tensor}.
244     * <p>
245     * Returns an object as the user doesn't know what type is in the {@link Tensor}.
246     *
247     * @param tensor The tensor to read.
248     * @return A primitive array.
249     */
250    public static Object convertTensorToArray(Tensor<?> tensor) {
251        long[] shape = tensor.shape();
252
253        Object array;
254        switch (tensor.dataType()) {
255            case FLOAT:
256                array = newFloatArray(shape);
257                break;
258            case DOUBLE:
259                array = newDoubleArray(shape);
260                break;
261            case INT32:
262                array = newIntArray(shape);
263                break;
264            case UINT8:
265            case STRING:
266                array = newByteArray(shape);
267                break;
268            case INT64:
269                array = newLongArray(shape);
270                break;
271            case BOOL:
272                array = newBooleanArray(shape);
273                break;
274            default:
275                throw new IllegalArgumentException("Tribuo can't serialise Tensors with type " + tensor.dataType());
276        }
277
278        tensor.copyTo(array);
279
280        return array;
281    }
282
283    /**
284     * Converts a {@link Tensor} into a scalar object, boxing the primitive types.
285     * <p>
286     * Does not close the Tensor.
287     *
288     * @param tensor The tensor to convert.
289     * @return A boxed scalar.
290     */
291    public static Object convertTensorToScalar(Tensor<?> tensor) {
292        Object scalar;
293        switch (tensor.dataType()) {
294            case FLOAT:
295                scalar = tensor.floatValue();
296                break;
297            case DOUBLE:
298                scalar = tensor.doubleValue();
299                break;
300            case INT32:
301                scalar = tensor.intValue();
302                break;
303            case UINT8:
304                scalar = (byte) (tensor.intValue() & 0xFF);
305                break;
306            case STRING:
307                scalar = tensor.bytesValue();
308                break;
309            case INT64:
310                scalar = tensor.longValue();
311                break;
312            case BOOL:
313                scalar = tensor.booleanValue();
314                break;
315            default:
316                throw new IllegalArgumentException("Tribuo can't serialise Tensors with type " + tensor.dataType());
317        }
318        return scalar;
319    }
320
321    /**
322     * Annotates a graph with an extra placeholder and assign operation for each
323     * VariableV2. This allows the graph to be deserialised using {@link TensorflowUtil#deserialise(Session, Map)}.
324     * <p>
325     * This operation can either be done each time the Graph is loaded before deserialise is called,
326     * or once, and the updated graphDef persisted with the Map produced by serialise.
327     * <p>
328     * Requires a session to correctly get the output type of a VariableV2. This isn't strictly necessary,
329     * but there aren't typed ways to get outputs in the TF version we use.
330     *
331     * @param graph The graph to annotate.
332     * @param session The session to use.
333     */
334    public static void annotateGraph(Graph graph, Session session) {
335        List<String> variableNames = new ArrayList<>();
336        Map<String, Operation> opMap = new HashMap<>();
337        Iterator<Operation> opItr = graph.operations();
338        while (opItr.hasNext()) {
339            Operation op = opItr.next();
340            if (op.type().equals(VARIABLE_V2)) {
341                variableNames.add(op.name());
342                opMap.put(op.name(), op);
343            }
344        }
345
346        Session.Runner runner = session.runner();
347        for (String s : variableNames) {
348            runner.fetch(s);
349        }
350
351        List<Tensor<?>> output = runner.run();
352
353        if (output.size() != variableNames.size()) {
354            TensorflowUtil.closeTensorList(output);
355            throw new IllegalStateException("Failed to annotate all requested variables. Requested " + variableNames.size() + ", found " + output.size());
356        }
357
358        for (int i = 0; i < output.size(); i++) {
359            OperationBuilder builder = graph.opBuilder(PLACEHOLDER, generatePlaceholderName(variableNames.get(i)));
360            builder.setAttr(DTYPE, output.get(i).dataType());
361            Operation o = builder.build();
362            builder = graph.opBuilder(ASSIGN_OP, variableNames.get(i) + "/" + ASSIGN_PLACEHOLDER);
363            builder.addInput(opMap.get(variableNames.get(i)).output(0));
364            builder.addInput(o.output(0));
365            builder.build();
366        }
367
368        TensorflowUtil.closeTensorList(output);
369    }
370
371    /**
372     * Creates a name for a placeholder based on the supplied variable name.
373     *
374     * @param variableName The variable name to use as a base.
375     * @return A name for the placeholder.
376     */
377    public static String generatePlaceholderName(String variableName) {
378        return variableName + "-" + PLACEHOLDER;
379    }
380
381    /**
382     * Extracts a Map containing the name of each Tensorflow VariableV2 and the
383     * associated parameter array. This map can then be serialised to disk.
384     *
385     * @param graph The graph to read operations from.
386     * @param session The session to read from.
387     * @return A map containing all variable names and parameter arrays.
388     */
389    public static Map<String, Object> serialise(Graph graph, Session session) {
390        List<String> variableNames = new ArrayList<>();
391        Iterator<Operation> opItr = graph.operations();
392        while (opItr.hasNext()) {
393            Operation op = opItr.next();
394            if (op.type().equals(VARIABLE_V2)) {
395                variableNames.add(op.name());
396            }
397        }
398
399        Session.Runner runner = session.runner();
400        for (String s : variableNames) {
401            runner.fetch(s);
402        }
403        List<Tensor<?>> output = runner.run();
404
405        if (output.size() != variableNames.size()) {
406            closeTensorList(output);
407            throw new IllegalStateException("Failed to serialise all requested variables. Requested " + variableNames.size() + ", found " + output.size());
408        }
409
410        Map<String, Object> tensorMap = new HashMap<>();
411        for (int i = 0; i < variableNames.size(); i++) {
412            String name = variableNames.get(i);
413            Tensor<?> tensor = output.get(i);
414            Object value;
415            if (tensor.numDimensions() == 0) {
416                value = convertTensorToScalar(tensor);
417            } else {
418                value = convertTensorToArray(tensor);
419            }
420            tensorMap.put(name, value);
421        }
422
423        closeTensorList(output);
424
425        return tensorMap;
426    }
427
428    /**
429     * Writes a map containing the name of each Tensorflow VariableV2 and the associated
430     * parameter array into the supplied session.
431     *
432     * @param session   The session to write to.
433     * @param tensorMap The parameter map to write.
434     */
435    public static void deserialise(Session session, Map<String, Object> tensorMap) {
436        Session.Runner runner = session.runner();
437        List<Tensor<?>> tensors = new ArrayList<>();
438        for (Map.Entry<String, Object> e : tensorMap.entrySet()) {
439            logger.log(Level.FINEST, "Loading " + e.getKey() + " of type " + e.getValue().getClass().getName());
440            Tensor<?> tensor = Tensor.create(e.getValue());
441            runner.feed(generatePlaceholderName(e.getKey()), tensor);
442            runner.addTarget(e.getKey() + "/" + ASSIGN_PLACEHOLDER);
443            tensors.add(tensor);
444        }
445        runner.run();
446        closeTensorList(tensors);
447    }
448}