Class TensorflowSequenceTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.interop.tensorflow.sequence.TensorflowSequenceTrainer<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SequenceTrainer<T>
public class TensorflowSequenceTrainer<T extends Output<T>>
extends Object
implements SequenceTrainer<T>
A trainer for SequenceModels which use an underlying Tensorflow graph.
Nested Class Summary
Nested Classes
Modifier and Type    Class    Description
static class
Field Summary
Fields
Modifier and Type                         Field
protected int                             epochs
protected SequenceExampleTransformer<T>   exampleTransformer
protected String                          getLossOp
protected Path                            graphPath
protected String                          initOp
protected int                             loggingInterval
protected int                             minibatchSize
protected SequenceOutputTransformer<T>    outputTransformer
protected String                          predictOp
protected SplittableRandom                rng
protected long                            seed
protected int                             trainInvocationCounter
protected String                          trainOp
Constructor Summary
Constructors
Constructor    Description
TensorflowSequenceTrainer(Path graphPath, SequenceExampleTransformer<T> exampleTransformer, SequenceOutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int loggingInterval, long seed, String initOp, String trainOp, String getLossOp, String predictOp)
Method Summary
Modifier and Type    Method    Description
int                  getInvocationCount()    Returns the number of times the train method has been invoked.
void                 postConfig()
protected void       preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples)
String               toString()
SequenceModel<T>     train(SequenceDataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)    Trains a sequence prediction model using the examples in the given data set.

Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface org.tribuo.sequence.SequenceTrainer
train
Field Details

graphPath

exampleTransformer
@Config(mandatory=true, description="Sequence feature extractor.") protected SequenceExampleTransformer<T extends Output<T>> exampleTransformer

outputTransformer
@Config(mandatory=true, description="Sequence output extractor.") protected SequenceOutputTransformer<T extends Output<T>> outputTransformer

minibatchSize

epochs

loggingInterval

seed

initOp

trainOp

getLossOp

predictOp

rng

trainInvocationCounter
Constructor Details

TensorflowSequenceTrainer
public TensorflowSequenceTrainer(Path graphPath, SequenceExampleTransformer<T> exampleTransformer, SequenceOutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int loggingInterval, long seed, String initOp, String trainOp, String getLossOp, String predictOp) throws IOException
Throws:
IOException
Method Details

postConfig
public void postConfig() throws IOException
Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
Throws:
IOException
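postConfig is the OLCUT hook that runs after the configuration system has populated the fields, and here it is declared to throw IOException. The sketch below is minimal and rests on an assumption this page does not confirm: that postConfig reads the serialized graph from graphPath and seeds the RNG from seed. The class GraphLoadingSketch and its fields are hypothetical and use only java.nio.file. The real class would import the bytes into a TensorFlow graph.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.SplittableRandom;

/** Hypothetical stand-in for a configurable trainer; not Tribuo's implementation. */
class GraphLoadingSketch {
    // Fields a configuration system would populate before calling postConfig().
    Path graphPath;
    long seed;

    // State derived from the configured fields.
    byte[] graphBytes;
    SplittableRandom rng;

    /** Assumed behaviour: read the serialized graph bytes and seed the RNG. */
    void postConfig() throws IOException {
        if (graphPath == null) {
            throw new IllegalStateException("graphPath must be set before postConfig()");
        }
        // Throws an IOException (e.g. NoSuchFileException) if the file cannot be read.
        graphBytes = Files.readAllBytes(graphPath);
        rng = new SplittableRandom(seed);
    }
}
```

Reading the bytes eagerly means that a bad path surfaces as an IOException at configuration time rather than on the first call to train.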
train
public SequenceModel<T> train(SequenceDataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SequenceTrainer
Trains a sequence prediction model using the examples in the given data set.
Specified by:
train in interface SequenceTrainer<T extends Output<T>>
Parameters:
examples - the data set containing the examples.
runProvenance - training-run-specific provenance (e.g., fold number).
Returns:
a predictive model that can be used to generate predictions for new examples.
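The fields minibatchSize, epochs, loggingInterval, seed and rng suggest a conventional minibatch training loop. The sketch below assumes the loop works as follows, none of which this page confirms:

- Each epoch shuffles the example indices with a seeded SplittableRandom.
- The shuffled indices are sliced into minibatches of minibatchSize.
- Each minibatch gets one training step. The real trainer would run trainOp and getLossOp through an org.tensorflow.Session; here a caller-supplied function stands in for that step.
- The loss is recorded every loggingInterval steps.

MinibatchLoopSketch and its run method are hypothetical, not Tribuo API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.SplittableRandom;
import java.util.function.ToDoubleFunction;

/** Hypothetical minibatch/epoch loop; a sketch, not Tribuo's train() implementation. */
final class MinibatchLoopSketch {
    private MinibatchLoopSketch() {}

    /**
     * Runs {@code epochs} passes over {@code numExamples} examples. {@code step} receives the
     * example indices of one minibatch and returns its loss (a stand-in for running trainOp
     * and getLossOp). Returns the loss recorded every {@code loggingInterval} steps.
     */
    static List<Double> run(int numExamples, int minibatchSize, int epochs,
                            int loggingInterval, long seed, ToDoubleFunction<int[]> step) {
        if (minibatchSize <= 0 || loggingInterval <= 0) {
            throw new IllegalArgumentException("minibatchSize and loggingInterval must be positive");
        }
        SplittableRandom rng = new SplittableRandom(seed);
        int[] order = new int[numExamples];
        for (int i = 0; i < numExamples; i++) {
            order[i] = i;
        }
        List<Double> logged = new ArrayList<>();
        int stepCount = 0;
        // The real trainer would presumably run initOp once here, before the first epoch.
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Fisher-Yates shuffle: each epoch visits every example once, in a fresh order.
            for (int i = numExamples - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[j];
                order[j] = tmp;
            }
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                int[] batch = Arrays.copyOfRange(order, start, end);
                double loss = step.applyAsDouble(batch);
                stepCount++;
                if (stepCount % loggingInterval == 0) {
                    logged.add(loss); // a real trainer would write a log line instead
                }
            }
        }
        return logged;
    }
}
```

With a fixed seed, the seeded SplittableRandom makes the shuffle order, and therefore the sequence of minibatches, reproducible across runs.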
getInvocationCount
public int getInvocationCount()
Description copied from interface: SequenceTrainer
Returns the number of times the train method has been invoked.
Specified by:
getInvocationCount in interface SequenceTrainer<T extends Output<T>>
Returns:
the number of times train has been invoked.
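getInvocationCount is backed by the trainInvocationCounter field. The sketch below assumes two things this page does not state:

- The counter is incremented under a lock at the start of each train call.
- Each call derives its own random stream with SplittableRandom.split(), so repeated calls do not reuse the same shuffle.

CountingTrainerSketch and trainOnce are hypothetical names.

```java
import java.util.SplittableRandom;

/** Hypothetical trainer skeleton showing an invocation counter; not Tribuo code. */
class CountingTrainerSketch {
    private final SplittableRandom rng;
    private int trainInvocationCounter = 0;

    CountingTrainerSketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train(): returns a value drawn from this call's own RNG stream. */
    long trainOnce() {
        SplittableRandom localRng;
        // Split the RNG and bump the counter together, so concurrent calls each get
        // a distinct count and a distinct random stream.
        synchronized (this) {
            localRng = rng.split();
            trainInvocationCounter++;
        }
        return localRng.nextLong();
    }

    /** Mirrors getInvocationCount(): the number of train calls started so far. */
    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }
}
```

Because the parent RNG is seeded, two instances built with the same seed produce the same sequence of per-call streams.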
toString
public String toString()
Overrides:
toString in class Object
preTrainingHook
protected void preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples)
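preTrainingHook is protected and receives the TensorFlow Session together with the training dataset. That reads as an extension point that subclasses override to run extra setup before training. The sketch below illustrates that template-method pattern and assumes the hook runs once, after initialisation and before the first minibatch. So that the example runs without TensorFlow, a List<String> of executed "ops" stands in for the Session and a List<Integer> stands in for the dataset. All class names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical base trainer: train() calls an overridable hook before its loop. */
class HookedTrainerSketch {
    /** Records the "ops" that were run, standing in for a TensorFlow Session. */
    final List<String> opLog = new ArrayList<>();

    final void train(List<Integer> examples) {
        opLog.add("init");                 // stand-in for running initOp
        preTrainingHook(opLog, examples);  // subclass extension point
        for (Integer ignored : examples) {
            opLog.add("train");            // stand-in for running trainOp
        }
    }

    /** Default hook does nothing; subclasses override it to add setup steps. */
    protected void preTrainingHook(List<String> session, List<Integer> examples) {}
}

/** Hypothetical subclass that runs one extra setup step before training. */
class WarmStartTrainerSketch extends HookedTrainerSketch {
    @Override
    protected void preTrainingHook(List<String> session, List<Integer> examples) {
        session.add("load-embeddings:" + examples.size());
    }
}
```

Making train final and the hook protected lets subclasses add setup without reimplementing the training loop.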
getHyperparameterFeed
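This page gives no description or signature for getHyperparameterFeed. Judging only by its name, it plausibly supplies extra named inputs, such as a learning-rate placeholder, that are fed alongside each minibatch. The sketch below is a guess built on that reading:

- A protected method returns a map from placeholder name to value.
- The trainer merges that map into every step's feed.
- When a name appears in both maps, the batch's own inputs take precedence.

FeedSketch, LearningRateFeedSketch and the names "learning_rate" and "inputs" are all hypothetical.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical illustration of merging a hyperparameter feed into each step's inputs. */
class FeedSketch {
    /** Default: no extra hyperparameters are fed. Subclasses may override. */
    protected Map<String, Object> getHyperparameterFeed() {
        return Collections.emptyMap();
    }

    /** Builds the complete feed for one training step. */
    final Map<String, Object> buildStepFeed(Map<String, Object> batchFeed) {
        Map<String, Object> feed = new HashMap<>(getHyperparameterFeed());
        // Batch inputs are added last, so they win if a name appears in both maps.
        feed.putAll(batchFeed);
        return feed;
    }
}

/** Hypothetical subclass feeding a fixed learning rate to a named placeholder. */
class LearningRateFeedSketch extends FeedSketch {
    @Override
    protected Map<String, Object> getHyperparameterFeed() {
        return Map.of("learning_rate", 0.01f);
    }
}
```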
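getProvenance
public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>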