/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.interop.tensorflow.sequence;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.interop.tensorflow.TensorflowUtil;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceModel;
import org.tribuo.sequence.SequenceTrainer;
import org.tribuo.util.Util;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A trainer for {@link SequenceModel}s which use an underlying TensorFlow graph.
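 * <p>
 * A minimal usage sketch is shown below. The {@code Label} output type, the transformer
 * instances, and the graph operation names are illustrative assumptions rather than values
 * defined by this class.
 * <pre>{@code
 * // The constructor throws IOException if the graph protobuf cannot be read.
 * TensorflowSequenceTrainer<Label> trainer = new TensorflowSequenceTrainer<>(
 *         Paths.get("/path/to/graph.pb"),  // protobuf holding the graph definition
 *         exampleTransformer,              // a SequenceExampleTransformer<Label>
 *         outputTransformer,               // a SequenceOutputTransformer<Label>
 *         32,                              // minibatch size
 *         5,                               // epochs
 *         100,                             // logging interval
 *         1L,                              // RNG seed
 *         "init", "train", "loss", "predict"); // graph operation names
 * SequenceModel<Label> model = trainer.train(trainingDataset);
 * }</pre>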
 */
public class TensorflowSequenceTrainer<T extends Output<T>> implements SequenceTrainer<T> {

    private static final Logger log = Logger.getLogger(TensorflowSequenceTrainer.class.getName());

    @Config(mandatory=true,description="Path to the protobuf containing the Tensorflow graph.")
    protected Path graphPath;

    private byte[] graphDef;

    @Config(mandatory=true,description="Sequence feature extractor.")
    protected SequenceExampleTransformer<T> exampleTransformer;
    @Config(mandatory=true,description="Sequence output extractor.")
    protected SequenceOutputTransformer<T> outputTransformer;

    @Config(description="Minibatch size")
    protected int minibatchSize = 1;
    @Config(description="Number of SGD epochs to run.")
    protected int epochs = 5;
    @Config(description="Logging interval to print the loss.")
    protected int loggingInterval = 100;
    @Config(description="Seed for the RNG.")
    protected long seed = 1;

    @Config(mandatory=true,description="Name of the initialisation operation.")
    protected String initOp;
    @Config(mandatory=true,description="Name of the training operation.")
    protected String trainOp;
    @Config(mandatory=true,description="Name of the loss operation (to inspect the loss).")
    protected String getLossOp;
    @Config(mandatory=true,description="Name of the prediction operation.")
    protected String predictOp;

    protected SplittableRandom rng;

    protected int trainInvocationCounter;

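    /**
     * Constructs a trainer for TensorFlow sequence models using the supplied graph protobuf,
     * transformers and training hyperparameters. Reads the graph definition from
     * {@code graphPath} via {@link #postConfig()}.
     * @throws IOException If the graph protobuf could not be read.
     */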
    public TensorflowSequenceTrainer(Path graphPath,
                                     SequenceExampleTransformer<T> exampleTransformer,
                                     SequenceOutputTransformer<T> outputTransformer,
                                     int minibatchSize,
                                     int epochs,
                                     int loggingInterval,
                                     long seed,
                                     String initOp,
                                     String trainOp,
                                     String getLossOp,
                                     String predictOp) throws IOException {
        this.graphPath = graphPath;
        this.exampleTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.minibatchSize = minibatchSize;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.seed = seed;
        this.initOp = initOp;
        this.trainOp = trainOp;
        this.getLossOp = getLossOp;
        this.predictOp = predictOp;
        postConfig();
    }

    /** Constructor required by the OLCUT configuration system. */
    private TensorflowSequenceTrainer() { }

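    /**
     * Used by the OLCUT configuration system to finish construction: seeds the RNG from
     * {@code seed} and reads the graph protobuf bytes from {@code graphPath}.
     * @throws IOException If the graph protobuf could not be read.
     */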
    @Override
    public synchronized void postConfig() throws IOException {
        rng = new SplittableRandom(seed);
        graphDef = Files.readAllBytes(graphPath);
    }

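    /**
     * Trains a sequence model by running minibatch training steps over the supplied dataset:
     * the graph definition is loaded into a fresh {@link Session}, the init op is run, and then
     * {@code epochs} passes are made over the shuffled data, feeding the encoded examples and
     * outputs to the train op. The trained graph definition and variable values are captured
     * into the returned {@link TensorflowSequenceModel}.
     */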
    @Override
    public SequenceModel<T> train(SequenceDataset<T> examples, Map<String,Provenance> runProvenance) {
        // Creates a new RNG, captures the provenance, and increments the invocation count.
        SplittableRandom localRNG;
        TrainerProvenance provenance;
        synchronized(this) {
            localRNG = rng.split();
            provenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> labelMap = examples.getOutputIDInfo();
        ArrayList<SequenceExample<T>> batch = new ArrayList<>();

        int[] indices = Util.randperm(examples.size(), localRNG);

        try (Graph graph = new Graph();
             Session session = new Session(graph)) {
            //
            // Load the graph def into the session.
            graph.importGraphDef(graphDef);
            //
            // Initialise the variables.
            session.runner().addTarget(initOp).run();
            log.info("Initialised the model parameters");
            //
            // Run additional initialization routines, if needed.
            preTrainingHook(session, examples);

            int interval = 0;
            for (int i = 0; i < epochs; i++) {
                log.log(Level.INFO,"Starting epoch " + i);

                // Shuffle the order in which we'll look at examples
                Util.randpermInPlace(indices, localRNG);

                for (int j = 0; j < examples.size(); j += minibatchSize) {
                    batch.clear();
                    for (int k = j; k < (j+ minibatchSize) && k < examples.size(); k++) {
                        int ix = indices[k];
                        batch.add(examples.getExample(ix));
                    }
                    //
                    // Transform examples to tensors
                    Map<String, Tensor<?>> feed = exampleTransformer.encode(batch, featureMap);
                    //
                    // Add supervision
                    feed.putAll(outputTransformer.encode(batch, labelMap));
                    //
                    // Add any additional training hyperparameter values to the feed dict.
                    feed.putAll(getHyperparameterFeed());
                    //
                    // Populate the runner.
                    Session.Runner runner = session.runner();
                    for (Map.Entry<String, Tensor<?>> item : feed.entrySet()) {
                        runner.feed(item.getKey(), item.getValue());
                    }
                    //
                    // Run a training batch.
                    Tensor<?> loss = runner
                            .addTarget(trainOp)
                            .fetch(getLossOp)
                            .run()
                            .get(0);
                    if (interval % loggingInterval == 0) {
                        log.info(String.format("loss %-5.6f [epoch %-2d batch %-4d #(%d - %d)/%d]",
                                loss.floatValue(), i, interval, j, Math.min(examples.size(), j+minibatchSize), examples.size()));
                    }
                    interval++;
                    //
                    // Cleanup: close the tensors.
                    loss.close();
                    for (Tensor<?> tns : feed.values()) {
                        tns.close();
                    }
                }
            }

            // This call **must** happen before the trainedGraphDef is generated.
            TensorflowUtil.annotateGraph(graph,session);
            //
            // Generate the trained graph def.
            byte[] trainedGraphDef = graph.toGraphDef();
            Map<String,Object> tensorMap = TensorflowUtil.serialise(graph,session);
            ModelProvenance modelProvenance = new ModelProvenance(TensorflowSequenceModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), provenance, runProvenance);
            return new TensorflowSequenceModel<>(
                    "tf-sequence-model",
                    modelProvenance,
                    featureMap,
                    labelMap,
                    trainedGraphDef,
                    exampleTransformer,
                    outputTransformer,
                    initOp,
                    predictOp,
                    tensorMap
            );

        } catch (TensorFlowException e) {
            log.log(Level.SEVERE, "TensorFlow threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "TensorflowSequenceTrainer(graphPath="+graphPath.toString()+",exampleTransformer="
                +exampleTransformer.toString()+",outputTransformer="+outputTransformer.toString()
                +",minibatchSize="+ minibatchSize +",epochs="+ epochs +",seed="+seed+")";
    }

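    /**
     * A hook for subclasses to run additional initialisation against the session (for example,
     * loading pretrained weights) after the init op has run but before training starts.
     * The default implementation does nothing.
     * @param session The TensorFlow session used for training.
     * @param examples The training dataset.
     */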
    protected void preTrainingHook(Session session, SequenceDataset<T> examples) {}

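    /**
     * Returns extra tensors to add to each training step's feed dict (for example a dropout
     * probability or learning rate placeholder). The default implementation returns an empty
     * map. Tensors returned here are closed by the training loop after each minibatch.
     * <p>
     * An illustrative override sketch (the placeholder name {@code "dropout_keep_prob"} is an
     * assumption about the graph, not something this class defines):
     * <pre>{@code
     * protected Map<String,Tensor<?>> getHyperparameterFeed() {
     *     return Collections.singletonMap("dropout_keep_prob", org.tensorflow.Tensors.create(0.75f));
     * }
     * }</pre>
     * @return A map from placeholder names to tensor values.
     */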
    protected Map<String, Tensor<?>> getHyperparameterFeed() {
        return Collections.emptyMap();
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TensorflowSequenceTrainerProvenance(this);
    }

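    /**
     * Provenance for {@link TensorflowSequenceTrainer}, recording the trainer configuration
     * along with a hash of the graph protobuf and the time it was last modified.
     */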
    public static class TensorflowSequenceTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;

        public static final String GRAPH_HASH = "graph-hash";
        public static final String GRAPH_LAST_MOD = "graph-last-modified";

        private final StringProvenance graphHash;
        private final DateTimeProvenance graphLastModified;

        <T extends Output<T>> TensorflowSequenceTrainerProvenance(TensorflowSequenceTrainer<T> host) {
            super(host);
            // instance parameters
            this.graphHash = new StringProvenance(GRAPH_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.graphPath));
            this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD,OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.graphPath.toFile().lastModified()), ZoneId.systemDefault()));
        }

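        /**
         * Reconstructs the provenance from a marshalled provenance map, moving the graph hash
         * and last modified time from the configured parameters into the instance values.
         * @param map The provenance map.
         */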
        public TensorflowSequenceTrainerProvenance(Map<String,Provenance> map) {
            this(extractTFProvenanceInfo(map));
        }

        private TensorflowSequenceTrainerProvenance(ExtractedInfo info) {
            super(info);
            this.graphHash = (StringProvenance) info.instanceValues.get(GRAPH_HASH);
            this.graphLastModified = (DateTimeProvenance) info.instanceValues.get(GRAPH_LAST_MOD);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();

            map.put(graphHash.getKey(),graphHash);
            map.put(graphLastModified.getKey(),graphLastModified);

            return map;
        }

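        /**
         * Extracts the graph hash and graph last modified time from the configured parameters
         * into the instance values, throwing {@link ProvenanceException} if either is missing
         * or has an unexpected provenance type.
         * @param map The provenance map.
         * @return The extracted provenance information.
         */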
        protected static ExtractedInfo extractTFProvenanceInfo(Map<String,Provenance> map) {
            ExtractedInfo info = SkeletalTrainerProvenance.extractProvenanceInfo(map);

            if (info.configuredParameters.containsKey(GRAPH_HASH)) {
                Provenance tmpProv = info.configuredParameters.remove(GRAPH_HASH);
                if (tmpProv instanceof StringProvenance) {
                    info.instanceValues.put(GRAPH_HASH,(StringProvenance) tmpProv);
                } else {
                    throw new ProvenanceException(GRAPH_HASH + " was not of type StringProvenance in class " + info.className);
                }
            } else {
                throw new ProvenanceException("Failed to find " + GRAPH_HASH + " when constructing SkeletalTrainerProvenance");
            }
            if (info.configuredParameters.containsKey(GRAPH_LAST_MOD)) {
                Provenance tmpProv = info.configuredParameters.remove(GRAPH_LAST_MOD);
                if (tmpProv instanceof DateTimeProvenance) {
                    info.instanceValues.put(GRAPH_LAST_MOD,(DateTimeProvenance) tmpProv);
                } else {
                    throw new ProvenanceException(GRAPH_LAST_MOD + " was not of type DateTimeProvenance in class " + info.className);
                }
            } else {
                throw new ProvenanceException("Failed to find " + GRAPH_LAST_MOD + " when constructing SkeletalTrainerProvenance");
            }

            return info;
        }
    }
}