Class TensorflowSequenceTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.interop.tensorflow.sequence.TensorflowSequenceTrainer<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SequenceTrainer<T>
public class TensorflowSequenceTrainer<T extends Output<T>>
extends Object
implements SequenceTrainer<T>
A trainer for SequenceModels which use an underlying Tensorflow graph.
Nested Class Summary
Nested Classes
Modifier and Type | Class | Description
static class
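Field Summary
Fields
Modifier and Type | Field
protected int | epochs
protected SequenceExampleTransformer<T> | exampleTransformer
protected String | getLossOp
protected Path | graphPath
protected String | initOp
protected int | loggingInterval
protected int | minibatchSize
protected SequenceOutputTransformer<T> | outputTransformer
protected String | predictOp
protected SplittableRandom | rng
protected long | seed
protected int | trainInvocationCounter
protected String | trainOp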
Constructor Summary
Constructors
TensorflowSequenceTrainer(Path graphPath, SequenceExampleTransformer<T> exampleTransformer, SequenceOutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int loggingInterval, long seed, String initOp, String trainOp, String getLossOp, String predictOp)
Method Summary
Modifier and Type | Method | Description
int | getInvocationCount() | Returns the number of times the train method has been invoked.
void | postConfig()
protected void | preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples)
String | toString()
SequenceModel<T> | train(SequenceDataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) | Trains a sequence prediction model using the examples in the given data set.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface org.tribuo.sequence.SequenceTrainer
train
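Field Details
graphPath
protected Path graphPath
exampleTransformer
@Config(mandatory=true, description="Sequence feature extractor.") protected SequenceExampleTransformer<T extends Output<T>> exampleTransformer
outputTransformer
@Config(mandatory=true, description="Sequence output extractor.") protected SequenceOutputTransformer<T extends Output<T>> outputTransformer
minibatchSize
protected int minibatchSize
epochs
protected int epochs
loggingInterval
protected int loggingInterval
seed
protected long seed
initOp
protected String initOp
trainOp
protected String trainOp
getLossOp
protected String getLossOp
predictOp
protected String predictOp
rng
protected SplittableRandom rng
trainInvocationCounter
protected int trainInvocationCounter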
Constructor Details
TensorflowSequenceTrainer
public TensorflowSequenceTrainer(Path graphPath, SequenceExampleTransformer<T> exampleTransformer, SequenceOutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int loggingInterval, long seed, String initOp, String trainOp, String getLossOp, String predictOp) throws IOException
Throws:
IOException
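The constructor takes minibatchSize, epochs, loggingInterval and seed, and the class keeps a SplittableRandom in its rng field. The following is a minimal, self-contained sketch of how such parameters commonly drive a shuffled minibatch loop. It is not Tribuo's implementation: the class MinibatchLoopSketch, its methods, and the choice of a per-epoch Fisher-Yates shuffle are illustrative assumptions, and the "training step" is only a placeholder for running the graph's train op.

```java
import java.util.SplittableRandom;

/**
 * Hypothetical sketch (not Tribuo's code) of how epochs, minibatchSize,
 * loggingInterval and seed could drive a shuffled minibatch training loop.
 */
public class MinibatchLoopSketch {

    /** In-place Fisher-Yates shuffle of an index array using the supplied RNG. */
    static void shuffle(int[] indices, SplittableRandom rng) {
        for (int i = indices.length - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
    }

    /**
     * Runs the loop and returns how many minibatches were processed.
     * Assumes loggingInterval > 0. The body of the inner loop is a placeholder.
     */
    static int run(int numExamples, int minibatchSize, int epochs, int loggingInterval, long seed) {
        SplittableRandom rng = new SplittableRandom(seed);
        int[] indices = new int[numExamples];
        for (int i = 0; i < numExamples; i++) {
            indices[i] = i;
        }
        int batchCount = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            shuffle(indices, rng); // new example order each epoch, reproducible from the seed
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                // Placeholder: a real trainer would feed indices[start..end) to its train op here.
                batchCount++;
                if (batchCount % loggingInterval == 0) {
                    System.out.printf("epoch %d, batch %d, positions [%d, %d)%n", epoch, batchCount, start, end);
                }
            }
        }
        return batchCount;
    }

    public static void main(String[] args) {
        // 10 examples in minibatches of 4 gives 3 batches per epoch (4, 4, 2), so 6 over 2 epochs.
        int batches = run(10, 4, 2, 2, 42L);
        System.out.println("total batches: " + batches);
    }
}
```

Seeding a single generator up front makes the whole sequence of shuffles reproducible, which is the usual reason a trainer exposes a seed parameter.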
Method Details
postConfig
public void postConfig() throws IOException
Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
Throws:
IOException
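In OLCUT, a Configurable object is instantiated by the configuration system, its @Config fields are injected, and postConfig() is then called to finish initialization. The sketch below illustrates that lifecycle with a hypothetical class, PostConfigSketch. It assumes, without confirmation from this page, that postConfig() is where the graph at graphPath is read, which would be consistent with it declaring IOException. The programmatic constructor reusing postConfig() is a common pattern, shown here as an assumption.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

/**
 * Hypothetical sketch of the configure-then-postConfig lifecycle. Not Tribuo's
 * code; it assumes postConfig() loads the serialized graph from graphPath.
 */
public class PostConfigSketch {
    // In an OLCUT-configured class these fields would carry @Config annotations
    // and be populated by the configuration system before postConfig() runs.
    protected Path graphPath;
    protected byte[] graphBytes;

    /** No-arg constructor, as used by a configuration system. */
    public PostConfigSketch() { }

    /** Programmatic constructor that reuses the same initialization path. */
    public PostConfigSketch(Path graphPath) throws IOException {
        this.graphPath = graphPath;
        postConfig();
    }

    /** Validates the configured fields, then reads the graph definition from disk. */
    public void postConfig() throws IOException {
        if (graphPath == null) {
            throw new IllegalStateException("graphPath must be set");
        }
        graphBytes = Files.readAllBytes(graphPath);
    }

    public static void main(String[] args) throws IOException {
        Path tmp = Files.createTempFile("graph", ".pb");
        Files.write(tmp, new byte[] {1, 2, 3});
        PostConfigSketch sketch = new PostConfigSketch(tmp);
        System.out.println("loaded " + sketch.graphBytes.length + " bytes");
        Files.delete(tmp);
    }
}
```

Routing both construction paths through postConfig() keeps validation and resource loading in one place, whether the object is built in code or from a configuration file.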
train
public SequenceModel<T> train(SequenceDataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SequenceTrainer
Trains a sequence prediction model using the examples in the given data set.
Specified by:
train in interface SequenceTrainer<T extends Output<T>>
Parameters:
examples - the data set containing the examples.
runProvenance - training-run-specific provenance (e.g., fold number).
Returns:
a predictive model that can be used to generate predictions for new examples.
getInvocationCount
public int getInvocationCount()
Description copied from interface: SequenceTrainer
Returns the number of times the train method has been invoked.
Specified by:
getInvocationCount in interface SequenceTrainer<T extends Output<T>>
Returns:
the number of times train has been invoked.
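The class holds both a trainInvocationCounter and an rng field. One plausible way they work together, shown here as a hypothetical sketch rather than Tribuo's source, is for each train call to increment the counter and split off a private SplittableRandom inside a single synchronized block. Repeated calls then get distinct random streams, and the sequence of streams is still reproducible from the seed. InvocationCounterSketch and its no-argument train() are illustrative stand-ins.

```java
import java.util.SplittableRandom;

/**
 * Hypothetical sketch pairing an invocation counter with a split RNG.
 * Not taken from Tribuo; train() is a stand-in for the real training method.
 */
public class InvocationCounterSketch {
    private final SplittableRandom rng;
    private int trainInvocationCounter = 0;

    public InvocationCounterSketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Bumps the counter and derives a per-call RNG; returns one draw from it. */
    public long train() {
        SplittableRandom localRng;
        synchronized (this) {
            localRng = rng.split();   // deterministic given the parent's state
            trainInvocationCounter++;
        }
        // A real trainer would shuffle and batch its data with localRng here.
        return localRng.nextLong();
    }

    public synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    public static void main(String[] args) {
        InvocationCounterSketch trainer = new InvocationCounterSketch(1L);
        trainer.train();
        trainer.train();
        System.out.println(trainer.getInvocationCount()); // prints 2
    }
}
```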
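toString
public String toString()
Overrides:
toString in class Object
preTrainingHook
protected void preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples)
getHyperparameterFeed
getProvenance
public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>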
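The protected preTrainingHook(Session, SequenceDataset) method suggests a template-method extension point, where a subclass can run extra setup against the session before training proceeds. The sketch below illustrates that pattern with stand-in types: a String in place of org.tensorflow.Session and a List<String> in place of SequenceDataset<T>. The assumption that the hook runs once, after initialization and before the first training step, is illustrative and not confirmed by this page.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical template-method sketch of a pre-training hook. The String and
 * List<String> parameters stand in for org.tensorflow.Session and
 * SequenceDataset<T>; neither is the real type.
 */
public class PreTrainingHookSketch {

    static class BaseTrainer {
        final List<String> log = new ArrayList<>();

        /** Default hook does nothing; subclasses may override it. */
        protected void preTrainingHook(String session, List<String> examples) { }

        /** Fixed skeleton: initialize, call the hook once, then run training steps. */
        public final void train(List<String> examples) {
            String session = "session";      // stand-in for opening a session
            log.add("init");                 // stand-in for running the init op
            preTrainingHook(session, examples);
            for (String example : examples) {
                log.add("step:" + example);  // stand-in for running the train op
            }
        }
    }

    static class WarmStartTrainer extends BaseTrainer {
        @Override
        protected void preTrainingHook(String session, List<String> examples) {
            log.add("hook saw " + examples.size() + " examples");
        }
    }

    public static void main(String[] args) {
        WarmStartTrainer trainer = new WarmStartTrainer();
        trainer.train(List.of("a", "b"));
        System.out.println(trainer.log); // [init, hook saw 2 examples, step:a, step:b]
    }
}
```

Making train final while leaving the hook protected lets subclasses customize setup without being able to reorder the core training steps.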