001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.tensorflow;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.Provenance;
023import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
026import org.tribuo.Dataset;
027import org.tribuo.Example;
028import org.tribuo.ImmutableFeatureMap;
029import org.tribuo.ImmutableOutputInfo;
030import org.tribuo.Model;
031import org.tribuo.Output;
032import org.tribuo.Trainer;
033import org.tribuo.provenance.ModelProvenance;
034import org.tribuo.provenance.SkeletalTrainerProvenance;
035import org.tribuo.provenance.TrainerProvenance;
036import org.tensorflow.Graph;
037import org.tensorflow.Session;
038import org.tensorflow.Tensor;
039import org.tensorflow.TensorFlowException;
040
041import java.io.IOException;
042import java.nio.file.Files;
043import java.nio.file.Path;
044import java.time.Instant;
045import java.time.OffsetDateTime;
046import java.time.ZoneId;
047import java.util.ArrayList;
048import java.util.Map;
049import java.util.logging.Level;
050import java.util.logging.Logger;
051
052/**
053 * Trainer for Tensorflow. Expects the underlying Tensorflow graph to have specific placeholders and
054 * targets listed below.
055 *
056 * <ul>
057 * <li>{@link TensorflowModel#INPUT_NAME} - the input minibatch.</li>
058 * <li>{@link TensorflowModel#OUTPUT_NAME} - the predicted output.</li>
059 * <li>{@link TensorflowTrainer#TARGET} - the output to predict.</li>
060 * <li>{@link TensorflowTrainer#TRAIN} - the train function to run (usually a single step of SGD).</li>
061 * <li>{@link TensorflowTrainer#TRAINING_LOSS} - the loss tensor to extract for logging.</li>
062 * <li>{@link TensorflowTrainer#EPOCH} - the current epoch number, used for gradient scaling.</li>
063 * <li>{@link TensorflowTrainer#IS_TRAINING} - a boolean placeholder to turn on dropout or other training specific functionality.</li>
064 * <li>{@link TensorflowTrainer#INIT} - the function to initialise the graph.</li>
065 * </ul>
066 *
067 * This trainer only works with graphs setup for minibatches. To recover single example training just use a batch size of 1.
068 * <p>
069 * This trainer uses the serialisation functionality in {@link TensorflowUtil}, as opposed to a SavedModel or a checkpoint.
070 * <p>
071 * N.B. Tensorflow support is experimental and may change without a major version bump.
072 */
073public final class TensorflowTrainer<T extends Output<T>> implements Trainer<T> {
074
075    private static final Logger logger = Logger.getLogger(TensorflowTrainer.class.getName());
076
077    public static final String TARGET = "target";
078    public static final String TRAIN = "train";
079    public static final String TRAINING_LOSS = "training_loss";
080    public static final String EPOCH = "epoch";
081    public static final String IS_TRAINING = "is_training";
082    public static final String INIT = "init";
083
084    @Config(mandatory=true,description="Path to the protobuf containing the graph.")
085    private Path graphPath;
086
087    private byte[] graphDef;
088
089    @Config(description="Test time batch size.")
090    private int testBatchSize = 16;
091
092    @Config(mandatory=true,description="Feature extractor.")
093    private ExampleTransformer<T> exampleTransformer;
094
095    @Config(mandatory=true,description="Response extractor.")
096    private OutputTransformer<T> outputTransformer;
097
098    @Config(description="Minibatch size.")
099    private int minibatchSize = 1;
100
101    @Config(description="Number of SGD epochs to run.")
102    private int epochs = 5;
103
104    @Config(description="Logging interval to print out the loss.")
105    private int loggingInterval = 100;
106
107    private int trainInvocationCounter = 0;
108
109    /**
110     * for olcut.
111     */
112    private TensorflowTrainer() {}
113
114    /**
115     * Constructs a Trainer for a tensorflow graph.
116     * @param graphPath The path to the graph protobuf. Must have the targets and placeholders specified above.
117     * @param exampleTransformer The example transformer to convert a Tribuo {@link Example} into a {@link Tensor}.
118     * @param outputTransformer The output transformer to convert a Tribuo {@link Output} into a {@link Tensor} and back. This encodes the output type.
119     * @param minibatchSize The minibatch size to use in training.
120     * @param epochs The number of SGD epochs to run.
121     * @param testBatchSize The minibatch size to use at test time.
122     * @throws IOException If the graphPath is invalid or failed to load.
123     */
124    public TensorflowTrainer(Path graphPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize) throws IOException {
125        this.graphPath = graphPath;
126        this.exampleTransformer = exampleTransformer;
127        this.outputTransformer = outputTransformer;
128        this.minibatchSize = minibatchSize;
129        this.epochs = epochs;
130        this.testBatchSize = testBatchSize;
131        postConfig();
132    }
133
134    /**
135     * Constructs a Trainer for a tensorflow graph.
136     * @param graphDef The graph definition as a byte array. Must have the targets and placeholders specified above.
137     * @param exampleTransformer The example transformer to convert a Tribuo {@link Example} into a {@link Tensor}.
138     * @param outputTransformer The output transformer to convert a Tribuo {@link Output} into a {@link Tensor} and back. This encodes the output type.
139     * @param minibatchSize The minibatch size to use in training.
140     * @param epochs The number of SGD epochs to run.
141     * @param testBatchSize The minibatch size to use at test time.
142     */
143    public TensorflowTrainer(byte[] graphDef, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize) {
144        this.graphPath = null;
145        this.graphDef = graphDef;
146        this.exampleTransformer = exampleTransformer;
147        this.outputTransformer = outputTransformer;
148        this.minibatchSize = minibatchSize;
149        this.epochs = epochs;
150        this.testBatchSize = testBatchSize;
151    }
152
153    /**
154     * Used by the OLCUT configuration system, and should not be called by external code.
155     */
156    @Override
157    public void postConfig() throws IOException {
158        graphDef = Files.readAllBytes(graphPath);
159    }
160
161    @Override
162    public Model<T> train(Dataset<T> examples, Map<String,Provenance> runProvenance) {
163        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
164        ImmutableOutputInfo<T> outputInfo = examples.getOutputIDInfo();
165        ArrayList<Example<T>> batch = new ArrayList<>();
166        trainInvocationCounter++;
167
168        try (Graph graph = new Graph();
169             Session session = new Session(graph);
170             Tensor<?> isTraining = Tensor.create(true)) {
171            // Load in the graph definition
172            graph.importGraphDef(graphDef);
173
174            // Initialises the parameters.
175            session.runner().addTarget(INIT).run();
176            logger.info("Initialised the model parameters");
177
178            int interval = 0;
179            for (int i = 0; i < epochs; i++) {
180                logger.log(Level.INFO,"Starting epoch " + i);
181                Tensor<?> epoch = Tensor.create(i);
182                for (int j = 0; j < examples.size(); j += minibatchSize) {
183                    batch.clear();
184                    for (int k = j; k < (j+ minibatchSize) && k < examples.size(); k++) {
185                        batch.add(examples.getExample(k));
186                    }
187                    //logger.info("Batch = " + batch.size());
188                    Tensor<?> input = exampleTransformer.transform(batch,featureMap);
189                    Tensor<?> target = outputTransformer.transform(batch,outputInfo);
190                    Tensor<?> loss = session.runner()
191                            .feed(TensorflowModel.INPUT_NAME, input)
192                            .feed(TARGET, target)
193                            .feed(EPOCH, epoch)
194                            .feed(IS_TRAINING, isTraining)
195                            .addTarget(TRAIN)
196                            .fetch(TRAINING_LOSS)
197                            .run().get(0);
198                    if (interval % loggingInterval == 0) {
199                        logger.log(Level.INFO, "Training loss = " + loss.floatValue());
200                    }
201                    input.close();
202                    target.close();
203                    loss.close();
204                    interval++;
205                }
206                epoch.close();
207            }
208
209            //System.out.println("After training");
210            //TensorflowModel.print(session);
211
212            // This call **must** happen before the trainedGraphDef is generated.
213            TensorflowUtil.annotateGraph(graph,session);
214
215            byte[] trainedGraphDef = graph.toGraphDef();
216
217            Map<String,Object> tensorMap = TensorflowUtil.serialise(graph,session);
218
219            ModelProvenance modelProvenance = new ModelProvenance(TensorflowModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), runProvenance);
220            TensorflowModel<T> tfModel = new TensorflowModel<>("tf-model", modelProvenance, featureMap,
221                    outputInfo, trainedGraphDef, tensorMap, testBatchSize, exampleTransformer, outputTransformer);
222
223            return tfModel;
224        } catch (TensorFlowException e) {
225            logger.log(Level.SEVERE, "TensorFlow threw an error", e);
226            throw new IllegalStateException(e);
227        }
228    }
229
230    @Override
231    public String toString() {
232        String path = graphPath==null?"":graphPath.toString();
233        return "TensorflowTrainer(graphPath="+path+",exampleTransformer="
234                +exampleTransformer.toString()+",outputTransformer="+outputTransformer.toString()
235                +",minibatchSize="+ minibatchSize +",epochs="+ epochs +")";
236    }
237
238    @Override
239    public int getInvocationCount() {
240        return trainInvocationCounter;
241    }
242
243    @Override
244    public TrainerProvenance getProvenance() {
245        return new TensorflowTrainerProvenance(this);
246    }
247
248    public static final class TensorflowTrainerProvenance extends SkeletalTrainerProvenance {
249        private static final long serialVersionUID = 1L;
250
251        public static final String GRAPH_HASH = "graph-hash";
252        public static final String GRAPH_LAST_MOD = "graph-last-modified";
253
254        private final HashProvenance graphHash;
255        private final DateTimeProvenance graphLastModified;
256
257        <T extends Output<T>> TensorflowTrainerProvenance(TensorflowTrainer<T> host) {
258            super(host);
259            // instance parameters
260            if (host.graphPath != null) {
261                this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE,GRAPH_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.graphPath));
262                this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD, OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.graphPath.toFile().lastModified()), ZoneId.systemDefault()));
263            } else {
264                this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE,GRAPH_HASH,hashArray(DEFAULT_HASH_TYPE,host.graphDef));
265                this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD, OffsetDateTime.now());
266            }
267        }
268
269        public TensorflowTrainerProvenance(Map<String,Provenance> map) {
270            this(extractTFProvenanceInfo(map));
271        }
272
273        private TensorflowTrainerProvenance(ExtractedInfo info) {
274            super(info);
275            this.graphHash = (HashProvenance) info.instanceValues.get(GRAPH_HASH);
276            this.graphLastModified = (DateTimeProvenance) info.instanceValues.get(GRAPH_LAST_MOD);
277        }
278
279        /**
280         * Hashes a byte array using the specified {@link com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil.HashType}.
281         * @param hashType The type of hash to perform.
282         * @param input The input array.
283         * @return A hexadecimal string representation of the hash.
284         */
285        private static String hashArray(com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil.HashType hashType, byte[] input) {
286            java.security.MessageDigest md = hashType.getDigest();
287            md.update(input);
288            return ProvenanceUtil.bytesToHexString(md.digest());
289        }
290
291        @Override
292        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
293            Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();
294
295            map.put(graphHash.getKey(),graphHash);
296            map.put(graphLastModified.getKey(),graphLastModified);
297
298            return map;
299        }
300
301        protected static ExtractedInfo extractTFProvenanceInfo(Map<String,Provenance> map) {
302            ExtractedInfo info = SkeletalTrainerProvenance.extractProvenanceInfo(map);
303            info.instanceValues.put(GRAPH_HASH,ObjectProvenance.checkAndExtractProvenance(map,GRAPH_HASH,HashProvenance.class,TensorflowTrainerProvenance.class.getSimpleName()));
304            info.instanceValues.put(GRAPH_LAST_MOD,ObjectProvenance.checkAndExtractProvenance(map,GRAPH_LAST_MOD,DateTimeProvenance.class,TensorflowTrainerProvenance.class.getSimpleName()));
305            return info;
306        }
307    }
308}