Class TensorFlowSequenceTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.interop.tensorflow.sequence.TensorFlowSequenceTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SequenceTrainer<T>
public class TensorFlowSequenceTrainer<T extends Output<T>>
extends Object
implements SequenceTrainer<T>
A trainer for SequenceModels which use an underlying TensorFlow graph.
N.B. TensorFlow support is experimental and may change without a major version bump.
Nested Class Summary
Field Summary
Modifier and Type                       Field
protected int                           epochs
protected SequenceFeatureConverter      featureConverter
protected String                        getLossOp
protected Path                          graphPath
protected int                           loggingInterval
protected int                           minibatchSize
protected SequenceOutputConverter<T>    outputConverter
protected String                        predictOp
protected SplittableRandom              rng
protected long                          seed
protected int                           trainInvocationCounter
protected String                        trainOp
Constructor Summary
Constructor  Description
TensorFlowSequenceTrainer(Path graphPath, SequenceFeatureConverter featureConverter, SequenceOutputConverter<T> outputConverter, int minibatchSize, int epochs, int loggingInterval, long seed, String trainOp, String getLossOp, String predictOp)
    Constructs a TensorFlowSequenceTrainer using the specified parameters.
Method Summary
Modifier and Type  Method  Description
protected TensorMap getHyperparameterFeed()
    Build any necessary non-data parameter tensors.
int getInvocationCount()
    Returns the number of times the train method has been invoked.
TrainerProvenance getProvenance()
void postConfig()
    Used by the OLCUT configuration system, and should not be called by external code.
protected void preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples)
    A hook for modifying the session state before training starts.
String toString()
SequenceModel<T> train(SequenceDataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
    Trains a sequence prediction model using the examples in the given data set.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface org.tribuo.sequence.SequenceTrainer
train
Field Details
graphPath
@Config(mandatory=true, description="Path to the protobuf containing the TensorFlow graph.") protected Path graphPath
featureConverter
@Config(mandatory=true, description="Sequence feature extractor.") protected SequenceFeatureConverter featureConverter
outputConverter
@Config(mandatory=true, description="Sequence output extractor.") protected SequenceOutputConverter<T extends Output<T>> outputConverter
minibatchSize
@Config(description="Minibatch size") protected int minibatchSize
epochs
@Config(description="Number of SGD epochs to run.") protected int epochs
loggingInterval
@Config(description="Logging interval to print the loss.") protected int loggingInterval
seed
@Config(description="Seed for the RNG.") protected long seed
trainOp
protected String trainOp
getLossOp
@Config(mandatory=true, description="Name of the loss operation (to inspect the loss).") protected String getLossOp
predictOp
protected String predictOp
rng
protected SplittableRandom rng
trainInvocationCounter
protected int trainInvocationCounter
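The minibatchSize, epochs, loggingInterval and seed fields describe a seeded minibatch SGD schedule. The following is a minimal, JDK-only sketch of such a schedule, not this class's actual training loop: it assumes the examples are reshuffled each epoch with a SplittableRandom seeded from seed, cut into minibatches of minibatchSize, and that the loss is logged every loggingInterval minibatches. The MinibatchSchedule class and its methods are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.SplittableRandom;

// Hypothetical helper showing how minibatchSize, epochs, loggingInterval and seed
// could drive a training loop. Not part of Tribuo.
final class MinibatchSchedule {
    private MinibatchSchedule() {}

    /**
     * Returns, for each epoch, its minibatches as arrays of example indices.
     * Every epoch reshuffles with a Fisher-Yates shuffle drawn from one seeded RNG,
     * so the whole schedule is reproducible for a given seed.
     */
    static List<List<int[]>> schedule(int numExamples, int minibatchSize, int epochs, long seed) {
        if (minibatchSize <= 0) {
            throw new IllegalArgumentException("minibatchSize must be positive");
        }
        SplittableRandom rng = new SplittableRandom(seed);
        int[] order = new int[numExamples];
        for (int i = 0; i < numExamples; i++) {
            order[i] = i;
        }
        List<List<int[]>> result = new ArrayList<>();
        for (int e = 0; e < epochs; e++) {
            // Fisher-Yates shuffle of the current example order.
            for (int i = numExamples - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[j];
                order[j] = tmp;
            }
            // Slice the shuffled order into minibatches; the last one may be short.
            List<int[]> batches = new ArrayList<>();
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                int[] batch = new int[end - start];
                System.arraycopy(order, start, batch, 0, end - start);
                batches.add(batch);
            }
            result.add(batches);
        }
        return result;
    }

    /** True if the loss should be logged after the given 0-based global minibatch counter. */
    static boolean shouldLog(int globalBatch, int loggingInterval) {
        return loggingInterval > 0 && (globalBatch + 1) % loggingInterval == 0;
    }
}
```

Drawing every epoch's shuffle from one seeded generator makes the batch order reproducible for a given seed.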
Constructor Details
TensorFlowSequenceTrainer
public TensorFlowSequenceTrainer(Path graphPath, SequenceFeatureConverter featureConverter, SequenceOutputConverter<T> outputConverter, int minibatchSize, int epochs, int loggingInterval, long seed, String trainOp, String getLossOp, String predictOp) throws IOException
Constructs a TensorFlowSequenceTrainer using the specified parameters.
- Parameters:
graphPath - The path to the TF graph.
featureConverter - The feature conversion object.
outputConverter - The output conversion object.
minibatchSize - The training minibatch size.
epochs - The number of training epochs.
loggingInterval - The interval at which the training loss is logged.
seed - The RNG seed.
trainOp - The name of the training operation.
getLossOp - The name of the loss operation.
predictOp - The name of the prediction operation.
- Throws:
IOException - If the graph cannot be read from disk.
Method Details
postConfig
public void postConfig() throws IOException
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Throws:
IOException
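The constructor and postConfig both declare IOException, which is consistent with the common OLCUT pattern of doing all derived initialization (here, reading the graph from graphPath) in postConfig and having the programmatic constructor call it after assigning fields. The sketch below illustrates that pattern under that assumption. GraphLoadingTrainerSketch is hypothetical, and a raw byte[] stands in for the parsed TensorFlow graph so the example needs only the JDK.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.SplittableRandom;

// Hypothetical, JDK-only sketch of the "constructor delegates to postConfig" pattern.
class GraphLoadingTrainerSketch {
    // In the real class these fields carry @Config annotations and may be injected by OLCUT.
    protected Path graphPath;
    protected long seed;

    // Derived state, (re)built by postConfig().
    protected byte[] graphDef;
    protected SplittableRandom rng;

    /** No-arg constructor, used by a configuration system before it injects fields. */
    GraphLoadingTrainerSketch() {}

    /** Programmatic constructor: assign fields, then run the shared initialization. */
    GraphLoadingTrainerSketch(Path graphPath, long seed) throws IOException {
        this.graphPath = graphPath;
        this.seed = seed;
        postConfig();
    }

    /** Reads the serialized graph and seeds the RNG; throws IOException if the file is unreadable. */
    public void postConfig() throws IOException {
        this.graphDef = Files.readAllBytes(graphPath);
        this.rng = new SplittableRandom(seed);
    }
}
```

Routing both construction paths through postConfig keeps configuration-built and code-built trainers initialized identically.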
train
public SequenceModel<T> train(SequenceDataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SequenceTrainer
Trains a sequence prediction model using the examples in the given data set.
- Specified by:
train in interface SequenceTrainer<T extends Output<T>>
- Parameters:
examples - the data set containing the examples.
runProvenance - Training run specific provenance (e.g., fold number).
- Returns:
a predictive model that can be used to generate predictions for new examples.
getInvocationCount
public int getInvocationCount()
Description copied from interface: SequenceTrainer
Returns the number of times the train method has been invoked.
- Specified by:
getInvocationCount in interface SequenceTrainer<T extends Output<T>>
- Returns:
The number of times train has been invoked.
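A counter like trainInvocationCounter is often paired with the rng field so that each call to train gets its own, reproducible randomness. One way to do this, shown here as an assumption rather than as this class's code, is to split a child SplittableRandom and increment the counter inside one synchronized block at the start of each call. CountingTrainerSketch and trainOnce are hypothetical names.

```java
import java.util.SplittableRandom;

// Hypothetical sketch: per-call RNG splitting plus an invocation counter.
class CountingTrainerSketch {
    private final SplittableRandom rng;
    private int trainInvocationCounter = 0;

    CountingTrainerSketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train(...): returns a value drawn from this call's private RNG. */
    long trainOnce() {
        SplittableRandom localRng;
        // Split the RNG and bump the counter together, so concurrent calls each
        // receive a distinct child RNG and the count stays consistent.
        synchronized (this) {
            localRng = rng.split();
            trainInvocationCounter++;
        }
        // The (potentially long) training work would use only localRng from here on.
        return localRng.nextLong();
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }
}
```

Two instances built with the same seed produce the same sequence of per-call values, while repeated calls on one instance each receive a distinct child generator.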
toString
public String toString()
- Overrides:
toString in class Object
preTrainingHook
protected void preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples)
A hook for modifying the session state before training starts. This should not mutate any state in the trainer.
- Parameters:
session - The session to modify.
examples - The dataset.
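preTrainingHook is a template-method extension point: the trainer calls it with the session before training starts, and a subclass may override it to prepare that session while leaving trainer fields untouched. The sketch below shows the shape of that pattern with JDK-only stand-ins. FakeSession, RecordingSession, BaseTrainerSketch and WarmStartTrainerSketch are hypothetical and do not model the real org.tensorflow.Session API; the "restore_weights" operation name is invented for illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a TensorFlow session.
interface FakeSession {
    void run(String opName);
}

// Records every operation run on it, so the call order can be inspected.
class RecordingSession implements FakeSession {
    final List<String> ops = new ArrayList<>();

    @Override
    public void run(String opName) {
        ops.add(opName);
    }
}

// Hypothetical base trainer: train() is fixed, the hook is overridable.
abstract class BaseTrainerSketch {
    private final String trainOp;

    BaseTrainerSketch(String trainOp) {
        this.trainOp = trainOp;
    }

    /** Default hook does nothing; overrides may modify the session but not the trainer. */
    protected void preTrainingHook(FakeSession session, List<?> examples) {}

    final void train(FakeSession session, List<?> examples, int steps) {
        preTrainingHook(session, examples); // runs once, before any training step
        for (int i = 0; i < steps; i++) {
            session.run(trainOp);
        }
    }
}

class WarmStartTrainerSketch extends BaseTrainerSketch {
    WarmStartTrainerSketch() {
        super("train");
    }

    @Override
    protected void preTrainingHook(FakeSession session, List<?> examples) {
        // Hypothetical op name, e.g. restoring saved variables before training.
        session.run("restore_weights");
    }
}
```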
getHyperparameterFeed
protected TensorMap getHyperparameterFeed()
Build any necessary non-data parameter tensors. The default implementation returns an empty map; subclasses should override it if they need to supply such tensors.
This should not mutate any state in the trainer.
- Returns:
The parameter tensors.
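getHyperparameterFeed lets a subclass supply non-data inputs, such as a dropout keep probability, which are presumably fed to the graph alongside the data tensors (an assumption about how the result is used). Because TensorMap's API is not shown on this page, the sketch below models the contract with a Map<String, Float>; the merge step in buildFeed, the "keep_prob" placeholder name and both classes are hypothetical. As documented, the default returns an empty map, and the override builds a fresh map on each call instead of mutating trainer state.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Hypothetical JDK-only model of the getHyperparameterFeed() contract.
class FeedingTrainerSketch {
    /** Default: no extra inputs, mirroring the documented empty-map default. */
    protected Map<String, Float> getHyperparameterFeed() {
        return Collections.emptyMap();
    }

    /** Combines per-minibatch data inputs with the hyperparameter inputs. */
    final Map<String, Float> buildFeed(Map<String, Float> dataFeed) {
        Map<String, Float> feed = new HashMap<>(dataFeed);
        feed.putAll(getHyperparameterFeed());
        return feed;
    }
}

class DropoutTrainerSketch extends FeedingTrainerSketch {
    private final float keepProb;

    DropoutTrainerSketch(float keepProb) {
        this.keepProb = keepProb;
    }

    @Override
    protected Map<String, Float> getHyperparameterFeed() {
        // Builds a new map on every call; the trainer's fields are only read.
        Map<String, Float> params = new HashMap<>();
        params.put("keep_prob", keepProb); // hypothetical placeholder name
        return params;
    }
}
```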
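getProvenance
public TrainerProvenance getProvenance()
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>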