/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trainer for Tensorflow. Expects the underlying Tensorflow graph to contain the specific placeholders and
 * targets listed below.
 *
 * <ul>
 * <li>{@link TensorflowModel#INPUT_NAME} - the input minibatch.</li>
 * <li>{@link TensorflowModel#OUTPUT_NAME} - the predicted output.</li>
 * <li>{@link TensorflowTrainer#TARGET} - the output to predict.</li>
 * <li>{@link TensorflowTrainer#TRAIN} - the training operation to run (usually a single step of SGD).</li>
 * <li>{@link TensorflowTrainer#TRAINING_LOSS} - the loss tensor to extract for logging.</li>
 * <li>{@link TensorflowTrainer#EPOCH} - the current epoch number, used for gradient scaling.</li>
 * <li>{@link TensorflowTrainer#IS_TRAINING} - a boolean placeholder which turns on dropout or other training-specific functionality.</li>
 * <li>{@link TensorflowTrainer#INIT} - the operation which initialises the graph.</li>
 * </ul>
 *
 * This trainer only works with graphs set up for minibatches. To recover single-example training, use a batch size of 1.
 * <p>
 * This trainer uses the serialisation functionality in {@link TensorflowUtil}, as opposed to a SavedModel or a checkpoint.
 * <p>
 * N.B. Tensorflow support is experimental and may change without a major version bump.
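 * <p>
 * A minimal usage sketch (the graph path and the {@code DenseTransformer}/{@code LabelTransformer}
 * instances below are illustrative assumptions; any {@link ExampleTransformer} and
 * {@link OutputTransformer} implementations matching your graph will do):
 * <pre>{@code
 * ExampleTransformer<Label> featureConverter = new DenseTransformer<>();   // assumed feature converter
 * OutputTransformer<Label> outputConverter = new LabelTransformer();       // assumed output converter
 * TensorflowTrainer<Label> trainer = new TensorflowTrainer<>(
 *         Paths.get("/path/to/graph.pb"),    // hypothetical graph protobuf
 *         featureConverter, outputConverter,
 *         16,     // minibatch size
 *         5,      // epochs
 *         16);    // test batch size
 * Model<Label> model = trainer.train(trainingDataset);
 * }</pre>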
 */
public final class TensorflowTrainer<T extends Output<T>> implements Trainer<T> {

    private static final Logger logger = Logger.getLogger(TensorflowTrainer.class.getName());

    public static final String TARGET = "target";
    public static final String TRAIN = "train";
    public static final String TRAINING_LOSS = "training_loss";
    public static final String EPOCH = "epoch";
    public static final String IS_TRAINING = "is_training";
    public static final String INIT = "init";

    @Config(mandatory=true,description="Path to the protobuf containing the graph.")
    private Path graphPath;

    private byte[] graphDef;

    @Config(description="Test time batch size.")
    private int testBatchSize = 16;

    @Config(mandatory=true,description="Feature extractor.")
    private ExampleTransformer<T> exampleTransformer;

    @Config(mandatory=true,description="Response extractor.")
    private OutputTransformer<T> outputTransformer;

    @Config(description="Minibatch size.")
    private int minibatchSize = 1;

    @Config(description="Number of SGD epochs to run.")
    private int epochs = 5;

    @Config(description="Logging interval to print out the loss.")
    private int loggingInterval = 100;

    private int trainInvocationCounter = 0;

    /**
     * For OLCUT.
     */
    private TensorflowTrainer() {}

    /**
     * Constructs a Trainer for a TensorFlow graph.
     * @param graphPath The path to the graph protobuf. Must contain the targets and placeholders specified above.
     * @param exampleTransformer The example transformer to convert a Tribuo {@link Example} into a {@link Tensor}.
     * @param outputTransformer The output transformer to convert a Tribuo {@link Output} into a {@link Tensor} and back. This encodes the output type.
     * @param minibatchSize The minibatch size to use in training.
     * @param epochs The number of SGD epochs to run.
     * @param testBatchSize The minibatch size to use at test time.
     * @throws IOException If the graph could not be loaded from the supplied path.
     */
    public TensorflowTrainer(Path graphPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize) throws IOException {
        this.graphPath = graphPath;
        this.exampleTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.minibatchSize = minibatchSize;
        this.epochs = epochs;
        this.testBatchSize = testBatchSize;
        postConfig();
    }

    /**
     * Constructs a Trainer for a TensorFlow graph.
     * @param graphDef The graph definition as a byte array. Must contain the targets and placeholders specified above.
     * @param exampleTransformer The example transformer to convert a Tribuo {@link Example} into a {@link Tensor}.
     * @param outputTransformer The output transformer to convert a Tribuo {@link Output} into a {@link Tensor} and back. This encodes the output type.
     * @param minibatchSize The minibatch size to use in training.
     * @param epochs The number of SGD epochs to run.
     * @param testBatchSize The minibatch size to use at test time.
     */
    public TensorflowTrainer(byte[] graphDef, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize) {
        this.graphPath = null;
        this.graphDef = graphDef;
        this.exampleTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.minibatchSize = minibatchSize;
        this.epochs = epochs;
        this.testBatchSize = testBatchSize;
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() throws IOException {
        graphDef = Files.readAllBytes(graphPath);
    }

    @Override
    public Model<T> train(Dataset<T> examples, Map<String,Provenance> runProvenance) {
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> outputInfo = examples.getOutputIDInfo();
        ArrayList<Example<T>> batch = new ArrayList<>();
        trainInvocationCounter++;

        try (Graph graph = new Graph();
             Session session = new Session(graph);
             Tensor<?> isTraining = Tensor.create(true)) {
            // Load in the graph definition
            graph.importGraphDef(graphDef);

            // Initialises the parameters.
            session.runner().addTarget(INIT).run();
            logger.info("Initialised the model parameters");

            int interval = 0;
            for (int i = 0; i < epochs; i++) {
                logger.log(Level.INFO,"Starting epoch " + i);
                Tensor<?> epoch = Tensor.create(i);
                for (int j = 0; j < examples.size(); j += minibatchSize) {
                    batch.clear();
                    for (int k = j; k < (j + minibatchSize) && k < examples.size(); k++) {
                        batch.add(examples.getExample(k));
                    }
                    //logger.info("Batch = " + batch.size());
                    Tensor<?> input = exampleTransformer.transform(batch,featureMap);
                    Tensor<?> target = outputTransformer.transform(batch,outputInfo);
                    Tensor<?> loss = session.runner()
                            .feed(TensorflowModel.INPUT_NAME, input)
                            .feed(TARGET, target)
                            .feed(EPOCH, epoch)
                            .feed(IS_TRAINING, isTraining)
                            .addTarget(TRAIN)
                            .fetch(TRAINING_LOSS)
                            .run().get(0);
                    if (interval % loggingInterval == 0) {
                        logger.log(Level.INFO, "Training loss = " + loss.floatValue());
                    }
                    input.close();
                    target.close();
                    loss.close();
                    interval++;
                }
                epoch.close();
            }

            //System.out.println("After training");
            //TensorflowModel.print(session);

            // This call **must** happen before the trainedGraphDef is generated.
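            // annotateGraph augments the graph with the ops needed to restore the session's
            // variable state on deserialisation, which is why it must run before toGraphDef().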
            TensorflowUtil.annotateGraph(graph,session);

            byte[] trainedGraphDef = graph.toGraphDef();

            Map<String,Object> tensorMap = TensorflowUtil.serialise(graph,session);

            ModelProvenance modelProvenance = new ModelProvenance(TensorflowModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), runProvenance);
            TensorflowModel<T> tfModel = new TensorflowModel<>("tf-model", modelProvenance, featureMap,
                    outputInfo, trainedGraphDef, tensorMap, testBatchSize, exampleTransformer, outputTransformer);

            return tfModel;
        } catch (TensorFlowException e) {
            logger.log(Level.SEVERE, "TensorFlow threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    @Override
    public String toString() {
        String path = graphPath==null?"":graphPath.toString();
        return "TensorflowTrainer(graphPath="+path+",exampleTransformer="
                +exampleTransformer.toString()+",outputTransformer="+outputTransformer.toString()
                +",minibatchSize="+minibatchSize+",epochs="+epochs+")";
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TensorflowTrainerProvenance(this);
    }

    public static final class TensorflowTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;

        public static final String GRAPH_HASH = "graph-hash";
        public static final String GRAPH_LAST_MOD = "graph-last-modified";

        private final HashProvenance graphHash;
        private final DateTimeProvenance graphLastModified;

        <T extends Output<T>> TensorflowTrainerProvenance(TensorflowTrainer<T> host) {
            super(host);
            // instance parameters
            if (host.graphPath != null) {
                this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE,GRAPH_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.graphPath));
                this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD, OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.graphPath.toFile().lastModified()), ZoneId.systemDefault()));
            } else {
                this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE,GRAPH_HASH,hashArray(DEFAULT_HASH_TYPE,host.graphDef));
                this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD, OffsetDateTime.now());
            }
        }

        public TensorflowTrainerProvenance(Map<String,Provenance> map) {
            this(extractTFProvenanceInfo(map));
        }

        private TensorflowTrainerProvenance(ExtractedInfo info) {
            super(info);
            this.graphHash = (HashProvenance) info.instanceValues.get(GRAPH_HASH);
            this.graphLastModified = (DateTimeProvenance) info.instanceValues.get(GRAPH_LAST_MOD);
        }

        /**
         * Hashes a byte array using the specified {@link com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil.HashType}.
         * @param hashType The type of hash to perform.
         * @param input The input array.
         * @return A hexadecimal string representation of the hash.
         */
        private static String hashArray(com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil.HashType hashType, byte[] input) {
            java.security.MessageDigest md = hashType.getDigest();
            md.update(input);
            return ProvenanceUtil.bytesToHexString(md.digest());
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();

            map.put(graphHash.getKey(),graphHash);
            map.put(graphLastModified.getKey(),graphLastModified);

            return map;
        }

        protected static ExtractedInfo extractTFProvenanceInfo(Map<String,Provenance> map) {
            ExtractedInfo info = SkeletalTrainerProvenance.extractProvenanceInfo(map);
            info.instanceValues.put(GRAPH_HASH,ObjectProvenance.checkAndExtractProvenance(map,GRAPH_HASH,HashProvenance.class,TensorflowTrainerProvenance.class.getSimpleName()));
            info.instanceValues.put(GRAPH_LAST_MOD,ObjectProvenance.checkAndExtractProvenance(map,GRAPH_LAST_MOD,DateTimeProvenance.class,TensorflowTrainerProvenance.class.getSimpleName()));
            return info;
        }
    }
}