public class TensorFlowSequenceTrainer<T extends Output<T>> extends Object implements SequenceTrainer<T>
N.B. TensorFlow support is experimental and may change without a major version bump.
Modifier and Type | Class and Description |
---|---|
static class | TensorFlowSequenceTrainer.TensorFlowSequenceTrainerProvenance |
Modifier and Type | Field and Description |
---|---|
protected int | epochs |
protected SequenceFeatureConverter | featureConverter |
protected String | getLossOp |
protected Path | graphPath |
protected int | loggingInterval |
protected int | minibatchSize |
protected SequenceOutputConverter<T> | outputConverter |
protected String | predictOp |
protected SplittableRandom | rng |
protected long | seed |
protected int | trainInvocationCounter |
protected String | trainOp |
Constructor and Description |
---|
TensorFlowSequenceTrainer(Path graphPath, SequenceFeatureConverter featureConverter, SequenceOutputConverter<T> outputConverter, int minibatchSize, int epochs, int loggingInterval, long seed, String trainOp, String getLossOp, String predictOp) Constructs a TensorFlowSequenceTrainer using the specified parameters. |
Modifier and Type | Method and Description |
---|---|
protected TensorMap | getHyperparameterFeed() Build any necessary non-data parameter tensors. |
int | getInvocationCount() Returns the number of times the train method has been invoked. |
TrainerProvenance | getProvenance() |
void | postConfig() Used by the OLCUT configuration system, and should not be called by external code. |
protected void | preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples) A hook for modifying the session state before training starts. |
String | toString() |
SequenceModel<T> | train(SequenceDataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) Trains a sequence prediction model using the examples in the given data set. |
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface SequenceTrainer
train
@Config(mandatory=true, description="Path to the protobuf containing the TensorFlow graph.") protected Path graphPath
@Config(mandatory=true, description="Sequence feature extractor.") protected SequenceFeatureConverter featureConverter
@Config(mandatory=true, description="Sequence output extractor.") protected SequenceOutputConverter<T extends Output<T>> outputConverter
@Config(description="Minibatch size") protected int minibatchSize
@Config(description="Number of SGD epochs to run.") protected int epochs
@Config(description="Logging interval to print the loss.") protected int loggingInterval
@Config(description="Seed for the RNG.") protected long seed
@Config(mandatory=true, description="Name of the training operation.") protected String trainOp
@Config(mandatory=true, description="Name of the loss operation (to inspect the loss).") protected String getLossOp
@Config(mandatory=true, description="Name of the prediction operation.") protected String predictOp
protected SplittableRandom rng
protected int trainInvocationCounter
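This page does not say how the trainer consumes `rng`, so the following is only a minimal sketch of what a fixed `seed` buys in general: two `SplittableRandom` instances built from the same seed produce the same sequence, so anything derived from them (here, a shuffled example order) is repeatable. The class `SeedSketch` and the method `shuffledIndices` are hypothetical names introduced for illustration, and the idea that the RNG drives example ordering is an assumption, not something this page confirms.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

// Hypothetical illustration class; it is not part of the documented API.
final class SeedSketch {

    /**
     * Returns the indices 0..n-1 in an order that depends only on the seed.
     * Uses a Fisher-Yates shuffle driven by a SplittableRandom.
     */
    static int[] shuffledIndices(int n, long seed) {
        SplittableRandom rng = new SplittableRandom(seed);
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        for (int i = n - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1); // uniform over [0, i]
            int tmp = idx[i];
            idx[i] = idx[j];
            idx[j] = tmp;
        }
        return idx;
    }

    public static void main(String[] args) {
        int[] first = shuffledIndices(8, 1L);
        int[] second = shuffledIndices(8, 1L);
        System.out.println(Arrays.toString(first));
        System.out.println("same seed, same order: " + Arrays.equals(first, second));
    }
}
```

If reproducible runs matter, passing an explicit `seed` to the constructor (or setting it in the configuration) is the knob this page exposes for that.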
public TensorFlowSequenceTrainer(Path graphPath, SequenceFeatureConverter featureConverter, SequenceOutputConverter<T> outputConverter, int minibatchSize, int epochs, int loggingInterval, long seed, String trainOp, String getLossOp, String predictOp) throws IOException
Parameters:
graphPath - The path to the TF graph.
featureConverter - The feature conversion object.
outputConverter - The output conversion object.
minibatchSize - The training minibatch size.
epochs - The number of training epochs.
loggingInterval - The logging interval.
seed - The RNG seed.
trainOp - The name of the training operation.
getLossOp - The name of the loss operation.
predictOp - The name of the prediction operation.
Throws:
IOException - If the graph cannot be read from disk.
public void postConfig() throws IOException
Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
Throws:
IOException
public SequenceModel<T> train(SequenceDataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
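Description copied from interface: SequenceTrainer
Trains a sequence prediction model using the examples in the given data set.
Specified by:
train in interface SequenceTrainer<T extends Output<T>>
Parameters:
examples - The data set containing the examples.
runProvenance - Training run specific provenance (e.g., fold number).
public int getInvocationCount()
Description copied from interface: SequenceTrainer
Returns the number of times the train method has been invoked.
Specified by:
getInvocationCount in interface SequenceTrainer<T extends Output<T>>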
protected void preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples)
A hook for modifying the session state before training starts.
This hook should not mutate any state in the trainer.
Parameters:
session - The session to modify.
examples - The dataset.
protected TensorMap getHyperparameterFeed()
Build any necessary non-data parameter tensors.
The default implementation returns an empty map; override it if necessary.
This method should not mutate any state in the trainer.
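The two protected hooks follow an override-if-needed pattern, which can be sketched without the real library. Everything below is a stand-in: `SketchTrainer` is not the Tribuo class, a plain `Map<String, Float>` replaces both `TensorMap` and `org.tensorflow.Session`, and the names `dropout` and `is_training` are hypothetical. How and when the real trainer calls these hooks is not stated on this page, so the `train()` skeleton here is an assumption made only to show the shape of a subclass that leaves trainer state untouched.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Hypothetical illustration; none of these classes are part of the documented API.
final class HookSketch {

    /** Stand-in for the trainer, mirroring only the two hook methods. */
    static class SketchTrainer {
        // Default: no extra non-data parameters, as the documented default returns an empty map.
        protected Map<String, Float> getHyperparameterFeed() {
            return Collections.emptyMap();
        }

        // Default: leave the (stand-in) session untouched.
        protected void preTrainingHook(Map<String, Float> session) {
        }

        /** Assumed skeleton: run the hook once, then combine session values with the feed. */
        public final Map<String, Float> train() {
            Map<String, Float> session = new HashMap<>();
            preTrainingHook(session);
            Map<String, Float> stepInputs = new HashMap<>(session);
            stepInputs.putAll(getHyperparameterFeed());
            return stepInputs;
        }
    }

    /** Subclass that supplies a value through each hook without changing its own fields. */
    static class DropoutTrainer extends SketchTrainer {
        private final float dropout;

        public DropoutTrainer(float dropout) {
            this.dropout = dropout;
        }

        @Override
        protected Map<String, Float> getHyperparameterFeed() {
            // A fresh map is built on every call; the trainer itself is not modified.
            Map<String, Float> feed = new HashMap<>();
            feed.put("dropout", dropout);
            return feed;
        }

        @Override
        protected void preTrainingHook(Map<String, Float> session) {
            // Only the session passed in is modified.
            session.put("is_training", 1.0f);
        }
    }

    public static void main(String[] args) {
        System.out.println(new SketchTrainer().train());
        System.out.println(new DropoutTrainer(0.5f).train());
    }
}
```

Because neither override writes to a field, calling `train()` repeatedly on the same instance yields the same inputs each time, which is the property the "should not mutate any state" notes ask for.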
public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.