/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.interop.tensorflow.TensorflowTrainer.TensorflowTrainerProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;
import org.tensorflow.Tensors;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trainer for TensorFlow. Expects the underlying TensorFlow graph to have the specific placeholders and
 * targets listed below.
 *
 * <ul>
 * <li>{@link TensorflowModel#INPUT_NAME} - the input minibatch.</li>
 * <li>{@link TensorflowModel#OUTPUT_NAME} - the predicted output.</li>
 * <li>{@link TensorflowTrainer#TARGET} - the output to predict.</li>
 * <li>{@link TensorflowTrainer#TRAIN} - the train function to run (usually a single step of SGD).</li>
 * <li>{@link TensorflowTrainer#TRAINING_LOSS} - the loss tensor to extract for logging.</li>
 * <li>{@link TensorflowTrainer#EPOCH} - the current epoch number, used for gradient scaling.</li>
 * <li>{@link TensorflowTrainer#IS_TRAINING} - a boolean placeholder used to turn on dropout or other training-specific functionality.</li>
 * <li>{@link TensorflowTrainer#INIT} - the function to initialise the graph.</li>
 * </ul>
 *
 * This trainer only works with graphs set up for minibatches. To recover single-example training, use a batch size of 1.
 * <p>
 * This trainer uses the native TensorFlow serialisation functionality and saves to a checkpoint on disk. It's much more
 * fragile than the {@link TensorflowTrainer}.
 * </p>
 * <p>
 * N.B. TensorFlow support is experimental and may change without a major version bump.
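 * </p>
 * <p>
 * A minimal usage sketch is shown below. The {@code trainFromGraph} helper, its paths, and the batch
 * size/epoch values are illustrative only; the transformer instances are assumed to be constructed
 * elsewhere so that they match the graph's input and output shapes.
 * </p>
 * <pre>{@code
 * public static <T extends Output<T>> Model<T> trainFromGraph(Path graphPath, Path checkpointRoot,
 *                                                             ExampleTransformer<T> featureTransformer,
 *                                                             OutputTransformer<T> outputTransformer,
 *                                                             Dataset<T> trainingData) throws IOException {
 *     // Batch size 16 and 5 epochs are illustrative values, not recommended defaults.
 *     TensorflowCheckpointTrainer<T> trainer = new TensorflowCheckpointTrainer<>(graphPath, checkpointRoot,
 *             featureTransformer, outputTransformer, 16, 5);
 *     // train() writes a TensorFlow checkpoint under checkpointRoot as a side effect of training.
 *     return trainer.train(trainingData);
 * }
 * }</pre>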
 */
public final class TensorflowCheckpointTrainer<T extends Output<T>> implements Trainer<T> {

    private static final Logger logger = Logger.getLogger(TensorflowCheckpointTrainer.class.getName());

    public static final String MODEL_FILENAME = "model";

    @Config(mandatory=true,description="Path to the protobuf containing the graph.")
    private Path graphPath;

    private byte[] graphDef;

    @Config(mandatory=true,description="Feature extractor.")
    private ExampleTransformer<T> exampleTransformer;

    @Config(mandatory=true,description="Response extractor.")
    private OutputTransformer<T> outputTransformer;

    @Config(description="Minibatch size.")
    private int minibatchSize = 1;

    @Config(description="Number of SGD epochs to run.")
    private int epochs = 5;

    @Config(description="Logging interval to print out the loss.")
    private int loggingInterval = 100;

    @Config(description="Path to write out the checkpoints.")
    private Path checkpointRootPath = Paths.get("/tmp/");

    private int trainInvocationCounter = 0;

    /**
     * For OLCUT.
     */
    private TensorflowCheckpointTrainer() {}

    /**
     * Builds a trainer using the supplied graph and arguments.
     * @param graphPath The graph to load.
     * @param checkpointRootPath The checkpoint path to save to.
     * @param exampleTransformer The feature transformer.
     * @param outputTransformer The output transformer.
     * @param minibatchSize The training batch size.
     * @param epochs The number of training epochs.
     * @throws IOException If the graph failed to load.
     */
    public TensorflowCheckpointTrainer(Path graphPath, Path checkpointRootPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs) throws IOException {
        this.graphPath = graphPath;
        this.checkpointRootPath = checkpointRootPath;
        this.exampleTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.minibatchSize = minibatchSize;
        this.epochs = epochs;
        postConfig();
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() throws IOException {
        graphDef = Files.readAllBytes(graphPath);
    }

    @Override
    public Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        Path checkpointPath;
        try {
            checkpointPath = Files.createTempDirectory(checkpointRootPath,"tensorflow-checkpoint");
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to create checkpoint directory at path " + checkpointRootPath,e);
            throw new IllegalStateException("Failed to create checkpoint directory at path " + checkpointRootPath,e);
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> outputInfo = examples.getOutputIDInfo();
        ArrayList<Example<T>> batch = new ArrayList<>();

        trainInvocationCounter++;

        try (Graph graph = new Graph();
             Session session = new Session(graph);
             Tensor<?> isTraining = Tensor.create(true);
             Tensor<String> checkpointPathTensor = Tensors.create(checkpointPath.toString()+"/"+MODEL_FILENAME)) {
            // Load in the graph definition
            graph.importGraphDef(graphDef);

            // Initialises the parameters.
            session.runner().addTarget(TensorflowTrainer.INIT).run();
            logger.info("Initialised the model parameters");

            int interval = 0;
            for (int i = 0; i < epochs; i++) {
                logger.log(Level.INFO,"Starting epoch " + i);
                Tensor<?> epoch = Tensor.create(i);
                for (int j = 0; j < examples.size(); j += minibatchSize) {
                    batch.clear();
                    for (int k = j; k < (j + minibatchSize) && k < examples.size(); k++) {
                        batch.add(examples.getExample(k));
                    }
                    //logger.info("Batch = " + batch.size());
                    Tensor<?> input = exampleTransformer.transform(batch,featureMap);
                    Tensor<?> target = outputTransformer.transform(batch,outputInfo);
                    Tensor<?> loss = session.runner()
                            .feed(TensorflowModel.INPUT_NAME, input)
                            .feed(TensorflowTrainer.TARGET, target)
                            .feed(TensorflowTrainer.EPOCH, epoch)
                            .feed(TensorflowTrainer.IS_TRAINING, isTraining)
                            .addTarget(TensorflowTrainer.TRAIN)
                            .fetch(TensorflowTrainer.TRAINING_LOSS)
                            .run().get(0);
                    if (interval % loggingInterval == 0) {
                        logger.log(Level.INFO, "Training loss = " + loss.floatValue());
                    }
                    input.close();
                    target.close();
                    loss.close();
                    interval++;
                }
                epoch.close();
            }

            // Run the graph's saver op to write the trained variables out to the checkpoint directory.
            session.runner().feed("save/Const", checkpointPathTensor).addTarget("save/control_dependency").run();

            byte[] trainedGraphDef = graph.toGraphDef();

            ModelProvenance modelProvenance = new ModelProvenance(TensorflowCheckpointModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), runProvenance);
            TensorflowCheckpointModel<T> tfModel = new TensorflowCheckpointModel<>("tf-model", modelProvenance, featureMap,
                    outputInfo, trainedGraphDef, checkpointPath.toString(), exampleTransformer, outputTransformer);

            return tfModel;
        } catch (TensorFlowException e) {
            logger.log(Level.SEVERE, "TensorFlow threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "TensorflowCheckpointTrainer(graphPath="+graphPath.toString()
                +",checkpointRootPath="+checkpointRootPath.toString()+",exampleTransformer="
                +exampleTransformer.toString()+",outputTransformer="+outputTransformer.toString()
                +",minibatchSize="+ minibatchSize +",epochs="+ epochs +")";
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TensorflowCheckpointTrainerProvenance(this);
    }

    public static final class TensorflowCheckpointTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;

        public static final String GRAPH_HASH = "graph-hash";
        public static final String GRAPH_LAST_MOD = "graph-last-modified";

        private final HashProvenance graphHash;
        private final DateTimeProvenance graphLastModified;

        <T extends Output<T>> TensorflowCheckpointTrainerProvenance(TensorflowCheckpointTrainer<T> host) {
            super(host);
            // instance parameters
            this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE,GRAPH_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.graphPath));
            this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD, OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.graphPath.toFile().lastModified()), ZoneId.systemDefault()));
        }

        public TensorflowCheckpointTrainerProvenance(Map<String,Provenance> map) {
            this(extractTFProvenanceInfo(map));
        }

        private TensorflowCheckpointTrainerProvenance(ExtractedInfo info) {
            super(info);
            this.graphHash = (HashProvenance) info.instanceValues.get(GRAPH_HASH);
            this.graphLastModified = (DateTimeProvenance) info.instanceValues.get(GRAPH_LAST_MOD);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();

            map.put(graphHash.getKey(),graphHash);
            map.put(graphLastModified.getKey(),graphLastModified);

            return map;
        }

        protected static ExtractedInfo extractTFProvenanceInfo(Map<String,Provenance> map) {
            ExtractedInfo info = SkeletalTrainerProvenance.extractProvenanceInfo(map);
            info.instanceValues.put(GRAPH_HASH, ObjectProvenance.checkAndExtractProvenance(map,GRAPH_HASH,HashProvenance.class, TensorflowTrainerProvenance.class.getSimpleName()));
            info.instanceValues.put(GRAPH_LAST_MOD,ObjectProvenance.checkAndExtractProvenance(map,GRAPH_LAST_MOD,DateTimeProvenance.class,TensorflowTrainerProvenance.class.getSimpleName()));
            return info;
        }
    }
}