/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.interop.tensorflow.sequence;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.interop.tensorflow.TensorflowUtil;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceModel;
import org.tribuo.sequence.SequenceTrainer;
import org.tribuo.util.Util;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A trainer for {@link SequenceModel}s which use an underlying Tensorflow graph.
 * <p>
 * The graph is loaded from a protobuf on disk, trained by minibatch SGD using the
 * named training operation, and the trained graph def plus variable values are
 * captured into the returned {@link TensorflowSequenceModel}.
 * <p>
 * Thread-safety: the RNG split, provenance capture and invocation count update are
 * performed under a lock on {@code this}; each {@code train} call then proceeds
 * with its own local RNG.
 */
public class TensorflowSequenceTrainer<T extends Output<T>> implements SequenceTrainer<T> {

    private static final Logger log = Logger.getLogger(TensorflowSequenceTrainer.class.getName());

    @Config(mandatory=true,description="Path to the protobuf containing the Tensorflow graph.")
    protected Path graphPath;

    // Raw graph def bytes, loaded from graphPath in postConfig().
    private byte[] graphDef;

    @Config(mandatory=true,description="Sequence feature extractor.")
    protected SequenceExampleTransformer<T> exampleTransformer;
    @Config(mandatory=true,description="Sequence output extractor.")
    protected SequenceOutputTransformer<T> outputTransformer;

    @Config(description="Minibatch size")
    protected int minibatchSize = 1;
    @Config(description="Number of SGD epochs to run.")
    protected int epochs = 5;
    @Config(description="Logging interval to print the loss.")
    protected int loggingInterval = 100;
    @Config(description="Seed for the RNG.")
    protected long seed = 1;

    @Config(mandatory=true,description="Name of the initialisation operation.")
    protected String initOp;
    @Config(mandatory=true,description="Name of the training operation.")
    protected String trainOp;
    @Config(mandatory=true,description="Name of the loss operation (to inspect the loss).")
    protected String getLossOp;
    @Config(mandatory=true,description="Name of the prediction operation.")
    protected String predictOp;

    // Master RNG; split() under lock to give each train() call an independent stream.
    protected SplittableRandom rng;

    // Guarded by synchronized(this) for writes; read via synchronized getInvocationCount().
    protected int trainInvocationCounter;

    /**
     * Constructs a TensorflowSequenceTrainer.
     *
     * @param graphPath Path to the protobuf containing the Tensorflow graph.
     * @param exampleTransformer Converts sequence examples into feed-dict tensors.
     * @param outputTransformer Converts outputs into supervision tensors and back.
     * @param minibatchSize Size of each SGD minibatch.
     * @param epochs Number of SGD epochs to run.
     * @param loggingInterval Number of batches between loss log messages.
     * @param seed RNG seed for example shuffling.
     * @param initOp Name of the variable initialisation operation.
     * @param trainOp Name of the training operation.
     * @param getLossOp Name of the loss fetch operation.
     * @param predictOp Name of the prediction operation.
     * @throws IOException If the graph protobuf could not be read from {@code graphPath}.
     */
    public TensorflowSequenceTrainer(Path graphPath,
                                     SequenceExampleTransformer<T> exampleTransformer,
                                     SequenceOutputTransformer<T> outputTransformer,
                                     int minibatchSize,
                                     int epochs,
                                     int loggingInterval,
                                     long seed,
                                     String initOp,
                                     String trainOp,
                                     String getLossOp,
                                     String predictOp) throws IOException {
        this.graphPath = graphPath;
        this.exampleTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.minibatchSize = minibatchSize;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.seed = seed;
        this.initOp = initOp;
        this.trainOp = trainOp;
        this.getLossOp = getLossOp;
        this.predictOp = predictOp;
        postConfig();
    }

    /** Constructor required by olcut config system. **/
    private TensorflowSequenceTrainer() { }

    /**
     * Initialises the RNG and loads the graph def bytes from {@link #graphPath}.
     *
     * @throws IOException If the graph protobuf could not be read.
     */
    @Override
    public synchronized void postConfig() throws IOException {
        rng = new SplittableRandom(seed);
        graphDef = Files.readAllBytes(graphPath);
    }

    /**
     * Trains a {@link TensorflowSequenceModel} on the supplied dataset.
     * <p>
     * Runs {@code epochs} passes of minibatch SGD over a freshly shuffled example
     * order, feeding tensors produced by the example and output transformers, and
     * closing every fed tensor (and the fetched loss) after each batch.
     *
     * @param examples The training dataset.
     * @param runProvenance Additional provenance to record for this run.
     * @return A trained sequence model wrapping the trained graph def.
     * @throws IllegalStateException If TensorFlow throws during training.
     */
    @Override
    public SequenceModel<T> train(SequenceDataset<T> examples, Map<String,Provenance> runProvenance) {
        // Creates a new RNG, adds one to the invocation count.
        SplittableRandom localRNG;
        TrainerProvenance provenance;
        synchronized(this) {
            localRNG = rng.split();
            provenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> labelMap = examples.getOutputIDInfo();
        ArrayList<SequenceExample<T>> batch = new ArrayList<>();

        int[] indices = Util.randperm(examples.size(), localRNG);

        try (Graph graph = new Graph();
             Session session = new Session(graph)) {
            //
            // Load the graph def into the session.
            graph.importGraphDef(graphDef);
            //
            // Initialise the variables.
            session.runner().addTarget(initOp).run();
            log.info("Initialised the model parameters");
            //
            // Run additional initialization routines, if needed.
            preTrainingHook(session, examples);

            int interval = 0;
            for (int i = 0; i < epochs; i++) {
                log.log(Level.INFO,"Starting epoch " + i);

                // Shuffle the order in which we'll look at examples
                Util.randpermInPlace(indices, localRNG);

                for (int j = 0; j < examples.size(); j += minibatchSize) {
                    batch.clear();
                    for (int k = j; k < (j+ minibatchSize) && k < examples.size(); k++) {
                        int ix = indices[k];
                        batch.add(examples.getExample(ix));
                    }
                    //
                    // Transform examples to tensors
                    Map<String, Tensor<?>> feed = exampleTransformer.encode(batch, featureMap);
                    //
                    // Add supervision
                    feed.putAll(outputTransformer.encode(batch, labelMap));
                    //
                    // Add any additional training hyperparameter values to the feed dict.
                    feed.putAll(getHyperparameterFeed());
                    //
                    // Populate the runner.
                    Session.Runner runner = session.runner();
                    for (Map.Entry<String, Tensor<?>> item : feed.entrySet()) {
                        runner.feed(item.getKey(), item.getValue());
                    }
                    //
                    // Run a training batch.
                    Tensor<?> loss = runner
                            .addTarget(trainOp)
                            .fetch(getLossOp)
                            .run()
                            .get(0);
                    if (interval % loggingInterval == 0) {
                        log.info(String.format("loss %-5.6f [epoch %-2d batch %-4d #(%d - %d)/%d]",
                                loss.floatValue(), i, interval, j, Math.min(examples.size(), j+minibatchSize), examples.size()));
                    }
                    interval++;
                    //
                    // Cleanup: close the tensors.
                    loss.close();
                    for (Tensor<?> tns : feed.values()) {
                        tns.close();
                    }
                }
            }

            // This call **must** happen before the trainedGraphDef is generated.
            TensorflowUtil.annotateGraph(graph,session);
            //
            // Generate the trained graph def.
            byte[] trainedGraphDef = graph.toGraphDef();
            Map<String,Object> tensorMap = TensorflowUtil.serialise(graph,session);
            ModelProvenance modelProvenance = new ModelProvenance(TensorflowSequenceModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), provenance, runProvenance);
            return new TensorflowSequenceModel<>(
                    "tf-sequence-model",
                    modelProvenance,
                    featureMap,
                    labelMap,
                    trainedGraphDef,
                    exampleTransformer,
                    outputTransformer,
                    initOp,
                    predictOp,
                    tensorMap
            );

        } catch (TensorFlowException e) {
            log.log(Level.SEVERE, "TensorFlow threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    /**
     * Returns the number of times {@code train} has been invoked.
     * <p>
     * Synchronized so the read observes increments performed under the same lock
     * inside {@link #train}.
     *
     * @return The train invocation count.
     */
    @Override
    public synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "TensorflowSequenceTrainer(graphPath="+graphPath.toString()+",exampleTransformer="
                +exampleTransformer.toString()+",outputTransformer="+outputTransformer.toString()
                +",minibatchSize="+ minibatchSize +",epochs="+ epochs +",seed="+seed+")";
    }

    /**
     * A hook for performing additional initialisation in a subclass before training
     * begins. Called after the init op has run. The default implementation is a no-op.
     *
     * @param session The TF session training will run in.
     * @param examples The training dataset.
     */
    protected void preTrainingHook(Session session, SequenceDataset<T> examples) {}

    /**
     * Supplies extra feed-dict entries (e.g. learning-rate placeholders) for each
     * training batch. The default implementation returns an empty map.
     *
     * @return Extra tensors to feed on every batch; they are closed after each batch.
     */
    protected Map<String, Tensor<?>> getHyperparameterFeed() {
        return Collections.emptyMap();
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TensorflowSequenceTrainerProvenance(this);
    }

    /**
     * Provenance for a {@link TensorflowSequenceTrainer}, additionally recording a
     * hash of the graph protobuf and its last-modified timestamp.
     */
    public static class TensorflowSequenceTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;

        public static final String GRAPH_HASH = "graph-hash";
        public static final String GRAPH_LAST_MOD = "graph-last-modified";

        // Stored as HashProvenance so reconstruction via extractTFProvenanceInfo
        // (which requires a HashProvenance) round-trips. Previously this was a
        // StringProvenance, which made map-based reconstruction always fail.
        private final HashProvenance graphHash;
        private final DateTimeProvenance graphLastModified;

        <T extends Output<T>> TensorflowSequenceTrainerProvenance(TensorflowSequenceTrainer<T> host) {
            super(host);
            // instance parameters
            this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE,GRAPH_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.graphPath));
            this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD,OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.graphPath.toFile().lastModified()), ZoneId.systemDefault()));
        }

        /**
         * Reconstructs this provenance from a marshalled provenance map.
         *
         * @param map The provenance map.
         */
        public TensorflowSequenceTrainerProvenance(Map<String,Provenance> map) {
            this(extractTFProvenanceInfo(map));
        }

        private TensorflowSequenceTrainerProvenance(ExtractedInfo info) {
            super(info);
            this.graphHash = (HashProvenance) info.instanceValues.get(GRAPH_HASH);
            this.graphLastModified = (DateTimeProvenance) info.instanceValues.get(GRAPH_LAST_MOD);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();

            map.put(graphHash.getKey(),graphHash);
            map.put(graphLastModified.getKey(),graphLastModified);

            return map;
        }

        /**
         * Splits the graph hash and last-modified values out of the configured
         * parameters into the instance values of the extracted info.
         *
         * @param map The provenance map.
         * @return The extracted info with graph hash and timestamp as instance values.
         * @throws ProvenanceException If either value is missing or of the wrong type.
         */
        protected static ExtractedInfo extractTFProvenanceInfo(Map<String,Provenance> map) {
            ExtractedInfo info = SkeletalTrainerProvenance.extractProvenanceInfo(map);

            if (info.configuredParameters.containsKey(GRAPH_HASH)) {
                Provenance tmpProv = info.configuredParameters.remove(GRAPH_HASH);
                if (tmpProv instanceof HashProvenance) {
                    info.instanceValues.put(GRAPH_HASH,(HashProvenance) tmpProv);
                } else {
                    throw new ProvenanceException(GRAPH_HASH + " was not of type HashProvenance in class " + info.className);
                }
            } else {
                throw new ProvenanceException("Failed to find " + GRAPH_HASH + " when constructing TensorflowSequenceTrainerProvenance");
            }
            if (info.configuredParameters.containsKey(GRAPH_LAST_MOD)) {
                Provenance tmpProv = info.configuredParameters.remove(GRAPH_LAST_MOD);
                if (tmpProv instanceof DateTimeProvenance) {
                    info.instanceValues.put(GRAPH_LAST_MOD,(DateTimeProvenance) tmpProv);
                } else {
                    throw new ProvenanceException(GRAPH_LAST_MOD + " was not of type DateTimeProvenance in class " + info.className);
                }
            } else {
                throw new ProvenanceException("Failed to find " + GRAPH_LAST_MOD + " when constructing TensorflowSequenceTrainerProvenance");
            }

            return info;
        }
    }
}