/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.interop.tensorflow.TensorflowTrainer.TensorflowTrainerProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;
import org.tensorflow.Tensors;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trainer for TensorFlow. Expects the underlying TensorFlow graph to have the specific placeholders and
 * targets listed below.
 *
 * <ul>
 * <li>{@link TensorflowModel#INPUT_NAME} - the input minibatch.</li>
 * <li>{@link TensorflowModel#OUTPUT_NAME} - the predicted output.</li>
 * <li>{@link TensorflowTrainer#TARGET} - the output to predict.</li>
 * <li>{@link TensorflowTrainer#TRAIN} - the training operation to run (usually a single step of SGD).</li>
 * <li>{@link TensorflowTrainer#TRAINING_LOSS} - the loss tensor to extract for logging.</li>
 * <li>{@link TensorflowTrainer#EPOCH} - the current epoch number, used for gradient scaling.</li>
 * <li>{@link TensorflowTrainer#IS_TRAINING} - a boolean placeholder which turns on dropout or other training-specific functionality.</li>
 * <li>{@link TensorflowTrainer#INIT} - the operation which initialises the graph.</li>
 * </ul>
 *
 * This trainer only works with graphs set up for minibatches. To recover single example training, use a batch size of 1.
 * <p>
 * This trainer uses the native TensorFlow serialisation functionality and saves a checkpoint to disk. It's much more
 * fragile than the {@link TensorflowTrainer}.
 * </p>
 * <p>
 * N.B. TensorFlow support is experimental and may change without a major version bump.
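 * <p>
 * A minimal usage sketch (assuming a graph protobuf on disk and suitable {@code ExampleTransformer} and
 * {@code OutputTransformer} instances for the output type; the paths and variable names below are illustrative):
 * <pre>{@code
 *     TensorflowCheckpointTrainer<Label> trainer = new TensorflowCheckpointTrainer<>(
 *             Paths.get("graph.pb"),          // protobuf containing the graph definition
 *             Paths.get("/tmp/checkpoints"),  // root directory for the checkpoints
 *             exampleTransformer,             // converts Examples into input Tensors
 *             outputTransformer,              // converts Outputs into target Tensors
 *             16,                             // minibatch size
 *             5);                             // number of SGD epochs
 *     Model<Label> model = trainer.train(trainingDataset);
 * }</pre>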
 */
public final class TensorflowCheckpointTrainer<T extends Output<T>> implements Trainer<T> {

    private static final Logger logger = Logger.getLogger(TensorflowCheckpointTrainer.class.getName());

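    /**
     * The filename prefix used for the checkpoint files written inside the checkpoint directory.
     */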
    public static final String MODEL_FILENAME = "model";

    @Config(mandatory=true,description="Path to the protobuf containing the graph.")
    private Path graphPath;

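    // The serialised graph definition, read from graphPath in postConfig().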
    private byte[] graphDef;

    @Config(mandatory=true,description="Feature extractor.")
    private ExampleTransformer<T> exampleTransformer;

    @Config(mandatory=true,description="Response extractor.")
    private OutputTransformer<T> outputTransformer;

    @Config(description="Minibatch size.")
    private int minibatchSize = 1;

    @Config(description="Number of SGD epochs to run.")
    private int epochs = 5;

    @Config(description="Logging interval to print out the loss.")
    private int loggingInterval = 100;

    @Config(description="Path to write out the checkpoints.")
    private Path checkpointRootPath = Paths.get("/tmp/");

    private int trainInvocationCounter = 0;

    /**
     * For OLCUT.
     */
    private TensorflowCheckpointTrainer() {}

    /**
     * Builds a trainer using the supplied graph and arguments.
     * @param graphPath The graph to load.
     * @param checkpointRootPath The checkpoint path to save to.
     * @param exampleTransformer The feature transformer.
     * @param outputTransformer The output transformer.
     * @param minibatchSize The training batch size.
     * @param epochs The number of training epochs.
     * @throws IOException If the graph failed to load.
     */
    public TensorflowCheckpointTrainer(Path graphPath, Path checkpointRootPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs) throws IOException {
        this.graphPath = graphPath;
        this.checkpointRootPath = checkpointRootPath;
        this.exampleTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.minibatchSize = minibatchSize;
        this.epochs = epochs;
        postConfig();
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() throws IOException {
        graphDef = Files.readAllBytes(graphPath);
    }

    @Override
    public Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        Path checkpointPath;
        try {
            checkpointPath = Files.createTempDirectory(checkpointRootPath,"tensorflow-checkpoint");
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to create checkpoint directory at path " + checkpointRootPath,e);
            throw new IllegalStateException("Failed to create checkpoint directory at path " + checkpointRootPath,e);
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> outputInfo = examples.getOutputIDInfo();
        ArrayList<Example<T>> batch = new ArrayList<>();

        trainInvocationCounter++;

        try (Graph graph = new Graph();
             Session session = new Session(graph);
             Tensor<?> isTraining = Tensor.create(true);
             Tensor<String> checkpointPathTensor = Tensors.create(checkpointPath.toString()+"/"+MODEL_FILENAME) ) {
            // Load in the graph definition
            graph.importGraphDef(graphDef);

            // Initialises the parameters.
            session.runner().addTarget(TensorflowTrainer.INIT).run();
            logger.info("Initialised the model parameters");

            int interval = 0;
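            // Run SGD for the configured number of epochs, feeding one minibatch at a time.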
            for (int i = 0; i < epochs; i++) {
                logger.log(Level.INFO,"Starting epoch " + i);
                Tensor<?> epoch = Tensor.create(i);
                for (int j = 0; j < examples.size(); j += minibatchSize) {
                    batch.clear();
                    for (int k = j; k < (j+ minibatchSize) && k < examples.size(); k++) {
                        batch.add(examples.getExample(k));
                    }
                    //logger.info("Batch = " + batch.size());
                    Tensor<?> input = exampleTransformer.transform(batch,featureMap);
                    Tensor<?> target = outputTransformer.transform(batch,outputInfo);
                    Tensor<?> loss = session.runner()
                            .feed(TensorflowModel.INPUT_NAME, input)
                            .feed(TensorflowTrainer.TARGET, target)
                            .feed(TensorflowTrainer.EPOCH, epoch)
                            .feed(TensorflowTrainer.IS_TRAINING, isTraining)
                            .addTarget(TensorflowTrainer.TRAIN)
                            .fetch(TensorflowTrainer.TRAINING_LOSS)
                            .run().get(0);
                    if (interval % loggingInterval == 0) {
                        logger.log(Level.INFO, "Training loss = " + loss.floatValue());
                    }
                    input.close();
                    target.close();
                    loss.close();
                    interval++;
                }
                epoch.close();
            }

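            // Save the trained variables to the checkpoint directory via the graph's saver nodes.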
            session.runner().feed("save/Const", checkpointPathTensor).addTarget("save/control_dependency").run();

            byte[] trainedGraphDef = graph.toGraphDef();

            ModelProvenance modelProvenance = new ModelProvenance(TensorflowCheckpointModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), runProvenance);
            TensorflowCheckpointModel<T> tfModel = new TensorflowCheckpointModel<>("tf-model", modelProvenance, featureMap,
                    outputInfo, trainedGraphDef, checkpointPath.toString(), exampleTransformer, outputTransformer);

            return tfModel;
        } catch (TensorFlowException e) {
            logger.log(Level.SEVERE, "TensorFlow threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "TensorflowCheckpointTrainer(graphPath="+graphPath.toString()
                +",checkpointRootPath="+checkpointRootPath.toString()+",exampleTransformer="
                +exampleTransformer.toString()+",outputTransformer="+outputTransformer.toString()
                +",minibatchSize="+ minibatchSize +",epochs="+ epochs +")";
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TensorflowCheckpointTrainerProvenance(this);
    }

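    /**
     * Provenance for {@link TensorflowCheckpointTrainer}.
     */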
    public static final class TensorflowCheckpointTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;

        public static final String GRAPH_HASH = "graph-hash";
        public static final String GRAPH_LAST_MOD = "graph-last-modified";

        private final HashProvenance graphHash;
        private final DateTimeProvenance graphLastModified;

        <T extends Output<T>> TensorflowCheckpointTrainerProvenance(TensorflowCheckpointTrainer<T> host) {
            super(host);
            // instance parameters
            this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE,GRAPH_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.graphPath));
            this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD, OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.graphPath.toFile().lastModified()), ZoneId.systemDefault()));
        }

        public TensorflowCheckpointTrainerProvenance(Map<String,Provenance> map) {
            this(extractTFProvenanceInfo(map));
        }

        private TensorflowCheckpointTrainerProvenance(ExtractedInfo info) {
            super(info);
            this.graphHash = (HashProvenance) info.instanceValues.get(GRAPH_HASH);
            this.graphLastModified = (DateTimeProvenance) info.instanceValues.get(GRAPH_LAST_MOD);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();

            map.put(graphHash.getKey(),graphHash);
            map.put(graphLastModified.getKey(),graphLastModified);

            return map;
        }

        protected static ExtractedInfo extractTFProvenanceInfo(Map<String,Provenance> map) {
            ExtractedInfo info = SkeletalTrainerProvenance.extractProvenanceInfo(map);
            info.instanceValues.put(GRAPH_HASH, ObjectProvenance.checkAndExtractProvenance(map,GRAPH_HASH,HashProvenance.class, TensorflowTrainerProvenance.class.getSimpleName()));
            info.instanceValues.put(GRAPH_LAST_MOD,ObjectProvenance.checkAndExtractProvenance(map,GRAPH_LAST_MOD,DateTimeProvenance.class,TensorflowTrainerProvenance.class.getSimpleName()));
            return info;
        }
    }
}