Class TensorFlowSequenceTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.interop.tensorflow.sequence.TensorFlowSequenceTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SequenceTrainer<T>
public class TensorFlowSequenceTrainer<T extends Output<T>>
extends Object
implements SequenceTrainer<T>
A trainer for SequenceModels which use an underlying TensorFlow graph.
N.B. TensorFlow support is experimental and may change without a major version bump.
-
Nested Class Summary
Nested Classes
static class
-
Field Summary
Fields
protected int epochs
protected SequenceFeatureConverter featureConverter
protected String getLossOp
protected Path graphPath
protected int loggingInterval
protected int minibatchSize
protected SequenceOutputConverter<T> outputConverter
protected String predictOp
protected SplittableRandom rng
protected long seed
protected int trainInvocationCounter
protected String trainOp
-
Constructor Summary
Constructors
TensorFlowSequenceTrainer(Path graphPath, SequenceFeatureConverter featureConverter, SequenceOutputConverter<T> outputConverter, int minibatchSize, int epochs, int loggingInterval, long seed, String trainOp, String getLossOp, String predictOp): Constructs a TensorFlowSequenceTrainer using the specified parameters.
-
Method Summary
Methods
protected TensorMap getHyperparameterFeed(): Build any necessary non-data parameter tensors.
int getInvocationCount(): Returns the number of times the train method has been invoked.
TrainerProvenance getProvenance()
void postConfig(): Used by the OLCUT configuration system, and should not be called by external code.
protected void preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples): A hook for modifying the session state before training starts.
String toString()
SequenceModel<T> train(SequenceDataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a sequence prediction model using the examples in the given data set.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface org.tribuo.sequence.SequenceTrainer
train
-
Field Details
-
graphPath
@Config(mandatory=true, description="Path to the protobuf containing the TensorFlow graph.") protected Path graphPath -
featureConverter
@Config(mandatory=true, description="Sequence feature extractor.") protected SequenceFeatureConverter featureConverter -
outputConverter
@Config(mandatory=true, description="Sequence output extractor.") protected SequenceOutputConverter<T extends Output<T>> outputConverter -
minibatchSize
@Config(description="Minibatch size") protected int minibatchSize -
epochs
@Config(description="Number of SGD epochs to run.") protected int epochs -
loggingInterval
@Config(description="Logging interval to print the loss.") protected int loggingInterval -
seed
@Config(description="Seed for the RNG.") protected long seed -
trainOp
-
getLossOp
@Config(mandatory=true, description="Name of the loss operation (to inspect the loss).") protected String getLossOp -
predictOp
-
rng
-
trainInvocationCounter
protected int trainInvocationCounter
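The seed and rng fields indicate that the trainer's randomness comes from a java.util.SplittableRandom. This page does not say how the trainer uses it, so the following is only a minimal JDK-only sketch of why a fixed seed gives repeatable behaviour. The class SeededShuffleSketch and its method are hypothetical and are not part of Tribuo.

```java
import java.util.SplittableRandom;

// Hypothetical helper, not part of Tribuo: shuffles example indices with a seeded RNG.
final class SeededShuffleSketch {
    // Fisher-Yates shuffle of the indices 0..n-1, driven by a SplittableRandom built from the seed.
    static int[] shuffledIndices(int n, long seed) {
        SplittableRandom rng = new SplittableRandom(seed);
        int[] indices = new int[n];
        for (int i = 0; i < n; i++) {
            indices[i] = i;
        }
        for (int i = n - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1); // uniform in [0, i]
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
        return indices;
    }
}
```

Calling shuffledIndices twice with the same n and seed returns the same permutation, which is the property a configured seed is meant to provide.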
-
-
Constructor Details
-
TensorFlowSequenceTrainer
public TensorFlowSequenceTrainer(Path graphPath, SequenceFeatureConverter featureConverter, SequenceOutputConverter<T> outputConverter, int minibatchSize, int epochs, int loggingInterval, long seed, String trainOp, String getLossOp, String predictOp) throws IOException
Constructs a TensorFlowSequenceTrainer using the specified parameters.
- Parameters:
graphPath - The path to the TF graph.
featureConverter - The feature conversion object.
outputConverter - The output conversion object.
minibatchSize - The training minibatch size.
epochs - The number of training epochs.
loggingInterval - The logging interval.
seed - The RNG seed.
trainOp - The name of the training operation.
getLossOp - The name of the loss operation.
predictOp - The name of the prediction operation.
- Throws:
IOException - If the graph cannot be read from disk.
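To illustrate how minibatchSize, epochs and loggingInterval typically interact, here is a generic minibatch loop skeleton in plain Java. It is a sketch under stated assumptions, not the Tribuo implementation: it assumes each epoch passes over the whole dataset once, the last minibatch of an epoch may be smaller, and loggingInterval counts minibatches between loss reports. The class TrainingLoopSketch is hypothetical.

```java
// Hypothetical illustration, not part of Tribuo: counts the work a minibatch loop performs.
final class TrainingLoopSketch {
    // Returns {minibatches run, loss reports made, examples processed}.
    static int[] run(int numExamples, int minibatchSize, int epochs, int loggingInterval) {
        int batches = 0;
        int logs = 0;
        int examplesSeen = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int start = 0; start < numExamples; start += minibatchSize) {
                // The final minibatch in an epoch is truncated to the dataset size.
                int end = Math.min(start + minibatchSize, numExamples);
                // A real trainer would feed examples [start, end) to the training operation here.
                examplesSeen += end - start;
                batches++;
                if (loggingInterval > 0 && batches % loggingInterval == 0) {
                    logs++; // stands in for reading the loss operation and logging it
                }
            }
        }
        return new int[] {batches, logs, examplesSeen};
    }
}
```

Under these assumptions, 10 examples with a minibatch size of 4 over 2 epochs gives 6 minibatches, and a logging interval of 2 reports the loss 3 times.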
-
-
Method Details
-
postConfig
public void postConfig() throws IOException
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Throws:
IOException
-
train
public SequenceModel<T> train(SequenceDataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SequenceTrainer
Trains a sequence prediction model using the examples in the given data set.
- Specified by:
train in interface SequenceTrainer<T extends Output<T>>
- Parameters:
examples - the data set containing the examples.
runProvenance - Training run specific provenance (e.g., fold number).
- Returns:
- a predictive model that can be used to generate predictions for new examples.
-
getInvocationCount
public int getInvocationCount()
Description copied from interface: SequenceTrainer
Returns the number of times the train method has been invoked.
- Specified by:
getInvocationCount in interface SequenceTrainer<T extends Output<T>>
- Returns:
- The number of times train has been invoked.
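The relationship between train and getInvocationCount can be sketched with a small stand-in class. This is an illustration only: CountingTrainerSketch is hypothetical, it uses a list of strings in place of a SequenceDataset, and whether the real trainer increments its counter before or after training is not stated on this page.

```java
import java.util.List;

// Hypothetical stand-in for a trainer, not the Tribuo class.
final class CountingTrainerSketch {
    private int trainInvocationCounter = 0;

    // Pretends to train and returns a description of the resulting "model".
    String train(List<String> examples) {
        trainInvocationCounter++;
        return "model from " + examples.size() + " examples, run " + trainInvocationCounter;
    }

    // Mirrors the documented contract: the number of times train has been invoked.
    int getInvocationCount() {
        return trainInvocationCounter;
    }
}
```

A fresh instance reports 0, and each call to train raises the count by one.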
-
toString
public String toString()
- Overrides:
toString in class Object
-
preTrainingHook
protected void preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples)
A hook for modifying the session state before training starts.
This should not mutate any state in the trainer.
- Parameters:
session - The session to modify.
examples - The dataset.
-
getHyperparameterFeed
protected TensorMap getHyperparameterFeed()
Build any necessary non-data parameter tensors.
The default implementation returns an empty map, and should be overridden if necessary.
This should not mutate any state in the trainer.
- Returns:
- The parameter tensors.
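The two protected methods preTrainingHook and getHyperparameterFeed are extension points for subclasses. The pattern can be sketched in plain Java as follows. This is a hedged illustration, not Tribuo code: both classes are hypothetical, Map<String, Float> stands in for TensorMap, a StringBuilder stands in for the TensorFlow session, and the "dropout" key is an invented example parameter.

```java
import java.util.Collections;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical stand-in for the trainer showing where the two hooks are called.
class HookedTrainerSketch {
    // Default: no extra parameters, mirroring the documented empty-map default.
    protected Map<String, Float> getHyperparameterFeed() {
        return Collections.emptyMap();
    }

    // Default hook leaves the session untouched.
    protected void preTrainingHook(StringBuilder session) { }

    // Template method: runs the hook once, then builds the parameter feed.
    final String train() {
        StringBuilder session = new StringBuilder("session");
        preTrainingHook(session);
        Map<String, Float> feed = getHyperparameterFeed();
        return session + " feed=" + new TreeMap<>(feed);
    }
}

// Hypothetical subclass overriding both hooks without touching trainer state.
class DropoutTrainerSketch extends HookedTrainerSketch {
    @Override
    protected Map<String, Float> getHyperparameterFeed() {
        return Map.of("dropout", 0.5f);
    }

    @Override
    protected void preTrainingHook(StringBuilder session) {
        session.append("+embeddings"); // modifies the session, not the trainer
    }
}
```

In this sketch both overrides only read or build values and modify the session object passed in, which matches the documented requirement that the hooks not mutate trainer state.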
-
getProvenance
public TrainerProvenance getProvenance()
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
-