public class TensorflowSequenceTrainer<T extends Output<T>> extends Object implements SequenceTrainer<T>
Modifier and Type | Class and Description |
---|---|
static class |
TensorflowSequenceTrainer.TensorflowSequenceTrainerProvenance |
Modifier and Type | Field and Description |
---|---|
protected int |
epochs |
protected SequenceExampleTransformer<T> |
exampleTransformer |
protected String |
getLossOp |
protected Path |
graphPath |
protected String |
initOp |
protected int |
loggingInterval |
protected int |
minibatchSize |
protected SequenceOutputTransformer<T> |
outputTransformer |
protected String |
predictOp |
protected SplittableRandom |
rng |
protected long |
seed |
protected int |
trainInvocationCounter |
protected String |
trainOp |
Constructor and Description |
---|
TensorflowSequenceTrainer(Path graphPath,
SequenceExampleTransformer<T> exampleTransformer,
SequenceOutputTransformer<T> outputTransformer,
int minibatchSize,
int epochs,
int loggingInterval,
long seed,
String initOp,
String trainOp,
String getLossOp,
String predictOp) |
Modifier and Type | Method and Description |
---|---|
protected Map<String,org.tensorflow.Tensor<?>> |
getHyperparameterFeed() |
int |
getInvocationCount()
Returns the number of times the train method has been invoked.
|
TrainerProvenance |
getProvenance() |
void |
postConfig() |
protected void |
preTrainingHook(org.tensorflow.Session session,
SequenceDataset<T> examples) |
String |
toString() |
SequenceModel<T> |
train(SequenceDataset<T> examples,
Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Trains a sequence prediction model using the examples in the given data set.
|
Methods inherited from class java.lang.Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface SequenceTrainer:
train
@Config(mandatory=true, description="Path to the protobuf containing the Tensorflow graph.") protected Path graphPath
@Config(mandatory=true, description="Sequence feature extractor.") protected SequenceExampleTransformer<T extends Output<T>> exampleTransformer
@Config(mandatory=true, description="Sequence output extractor.") protected SequenceOutputTransformer<T extends Output<T>> outputTransformer
@Config(description="Minibatch size") protected int minibatchSize
@Config(description="Number of SGD epochs to run.") protected int epochs
@Config(description="Logging interval to print the loss.") protected int loggingInterval
@Config(description="Seed for the RNG.") protected long seed
@Config(mandatory=true, description="Name of the initialisation operation.") protected String initOp
@Config(mandatory=true, description="Name of the training operation.") protected String trainOp
@Config(mandatory=true, description="Name of the loss operation (to inspect the loss).") protected String getLossOp
@Config(mandatory=true, description="Name of the prediction operation.") protected String predictOp
protected SplittableRandom rng
protected int trainInvocationCounter
public TensorflowSequenceTrainer(Path graphPath, SequenceExampleTransformer<T> exampleTransformer, SequenceOutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int loggingInterval, long seed, String initOp, String trainOp, String getLossOp, String predictOp) throws IOException
Throws:
IOException
public void postConfig() throws IOException
Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
Throws:
IOException
public SequenceModel<T> train(SequenceDataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
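Description copied from interface: SequenceTrainer
Trains a sequence prediction model using the examples in the given data set.
Specified by:
train in interface SequenceTrainer<T extends Output<T>>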
Parameters:
examples - the data set containing the examples.
runProvenance - Training run specific provenance (e.g., fold number).

public int getInvocationCount()
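Description copied from interface: SequenceTrainer
Returns the number of times the train method has been invoked.
Specified by:
getInvocationCount in interface SequenceTrainer<T extends Output<T>>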
protected void preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples)
public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.