Class TensorflowCheckpointTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.interop.tensorflow.TensorflowCheckpointTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>
public final class TensorflowCheckpointTrainer<T extends Output<T>>
extends Object
implements Trainer<T>
Trainer for Tensorflow. Expects the underlying Tensorflow graph to have the specific placeholders and targets listed below:
- TensorflowModel.INPUT_NAME - the input minibatch.
- TensorflowModel.OUTPUT_NAME - the predicted output.
- TensorflowTrainer.TARGET - the output to predict.
- TensorflowTrainer.TRAIN - the train function to run (usually a single step of SGD).
- TensorflowTrainer.TRAINING_LOSS - the loss tensor to extract for logging.
- TensorflowTrainer.EPOCH - the current epoch number, used for gradient scaling.
- TensorflowTrainer.IS_TRAINING - a boolean placeholder to turn on dropout or other training-specific functionality.
- TensorflowTrainer.INIT - the function to initialise the graph.
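The placeholder contract above can be checked before any training starts. The sketch below is a hypothetical helper, not part of Tribuo. It assumes the caller can list the op names present in the graph and knows which string each constant resolves to (this page does not give those values). It reports every role that is unmapped or whose op is absent from the graph.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Hypothetical pre-flight check (not a Tribuo API): verifies that a graph
 * exposes an op for every placeholder and target the trainer expects.
 */
final class GraphContractCheck {
    /** The roles listed in the class documentation, in documentation order. */
    static final List<String> ROLES = List.of(
            "TensorflowModel.INPUT_NAME",
            "TensorflowModel.OUTPUT_NAME",
            "TensorflowTrainer.TARGET",
            "TensorflowTrainer.TRAIN",
            "TensorflowTrainer.TRAINING_LOSS",
            "TensorflowTrainer.EPOCH",
            "TensorflowTrainer.IS_TRAINING",
            "TensorflowTrainer.INIT");

    private GraphContractCheck() {}

    /**
     * Returns the roles that are unmapped or whose op name is absent from the graph.
     *
     * @param roleToOp maps each role to the op name it resolves to (assumed known to the caller)
     * @param graphOps the op names present in the graph
     */
    static List<String> missingRoles(Map<String, String> roleToOp, Set<String> graphOps) {
        List<String> missing = new ArrayList<>();
        for (String role : ROLES) {
            String op = roleToOp.get(role);
            if (op == null || !graphOps.contains(op)) {
                missing.add(role);
            }
        }
        return missing;
    }
}
```

An empty result means every role resolved to an op in the graph. A non-empty result names what must be added to the graph before it is handed to the trainer.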
This trainer uses the native Tensorflow serialisation functionality and saves to a checkpoint on disk. It is much more fragile than the TensorflowTrainer.
N.B. Tensorflow support is experimental and may change without a major version bump.
-
Nested Class Summary
Nested Classes
Modifier and Type | Class | Description
static final class
-
Field Summary
Fields
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED
-
Constructor Summary
Constructors
Constructor | Description
TensorflowCheckpointTrainer(Path graphPath, Path checkpointRootPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs) | Builds a trainer using the supplied graph and arguments.
-
Method Summary
Modifier and Type | Method | Description
int | getInvocationCount() | The number of times this trainer instance has had its train method invoked.
TrainerProvenance | getProvenance() |
void | postConfig() | Used by the OLCUT configuration system, and should not be called by external code.
String | toString() |
Model<T> | train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) | Trains a predictive model using the examples in the given data set.
-
Field Details
-
MODEL_FILENAME
- See Also:
-
Constructor Details
-
TensorflowCheckpointTrainer
public TensorflowCheckpointTrainer(Path graphPath, Path checkpointRootPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs) throws IOException
Builds a trainer using the supplied graph and arguments.
- Parameters:
graphPath - The graph to load.
checkpointRootPath - The checkpoint path to save to.
exampleTransformer - The feature transformer.
outputTransformer - The output transformer.
minibatchSize - The training batch size.
epochs - The number of training epochs.
- Throws:
IOException - If the graph failed to load.
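The minibatchSize and epochs arguments suggest a conventional minibatch training loop. The sketch below is an assumption about that loop, not the trainer's actual code. It walks the dataset in contiguous minibatches each epoch, keeps a final partial batch, and passes the epoch number and an is-training flag to each step, mirroring the EPOCH and IS_TRAINING placeholders. TrainStep is a hypothetical stand-in for running the TRAIN target and fetching TRAINING_LOSS. Whether the real trainer shuffles the data or drops a partial batch is not stated on this page.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for one run of the TRAIN target; not a TensorFlow or Tribuo API. */
@FunctionalInterface
interface TrainStep {
    /** Trains on examples [start, end) and returns the training loss for that minibatch. */
    float run(int start, int end, int epoch, boolean isTraining);
}

final class MinibatchLoop {
    private MinibatchLoop() {}

    /** Minibatches per epoch, counting a final partial batch (ceiling division). */
    static int batchesPerEpoch(int numExamples, int minibatchSize) {
        return (numExamples + minibatchSize - 1) / minibatchSize;
    }

    /** Runs every minibatch of every epoch and returns the per-step losses in order. */
    static List<Float> train(int numExamples, int minibatchSize, int epochs, TrainStep step) {
        if (minibatchSize <= 0 || epochs < 0) {
            throw new IllegalArgumentException("minibatchSize must be positive and epochs non-negative");
        }
        List<Float> losses = new ArrayList<>();
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                // Passing true mirrors feeding IS_TRAINING, e.g. to enable dropout.
                losses.add(step.run(start, end, epoch, true));
            }
        }
        return losses;
    }
}
```

With 10 examples, a minibatchSize of 4 and 3 epochs, this loop runs 3 steps per epoch (batches of 4, 4 and 2 examples), for 9 steps in total.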
-
Method Details
-
postConfig
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Throws:
IOException
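Both the constructor and postConfig declare IOException, which suggests that the graph is loaded inside postConfig. The sketch below shows that common pattern under this assumption. CheckpointTrainerSketch and graphSize are hypothetical, and the argument validation is also an assumption. In an OLCUT component, the fields would be populated by the configuration system through @Config annotations rather than by hand. The configuration system creates the object with a no-argument constructor, sets the fields, and then calls postConfig(). The public constructor sets the same fields and calls postConfig() itself, so both paths share one initialisation routine.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical sketch of the constructor/postConfig pattern; not the Tribuo source. */
final class CheckpointTrainerSketch {
    // In an OLCUT component these would be @Config fields set reflectively.
    private Path graphPath;
    private int minibatchSize;
    private int epochs;

    // State derived from the configured fields during postConfig().
    private byte[] graphBytes;

    /** For a configuration system, which sets the fields and then calls postConfig(). */
    CheckpointTrainerSketch() {}

    /** For direct use: set the fields, then reuse the same initialisation path. */
    CheckpointTrainerSketch(Path graphPath, int minibatchSize, int epochs) throws IOException {
        this.graphPath = graphPath;
        this.minibatchSize = minibatchSize;
        this.epochs = epochs;
        postConfig();
    }

    /** Validates the arguments (assumed) and reads the serialised graph; fails if it is unreadable. */
    void postConfig() throws IOException {
        if (minibatchSize <= 0 || epochs <= 0) {
            throw new IllegalArgumentException("minibatchSize and epochs must be positive");
        }
        graphBytes = Files.readAllBytes(graphPath);
    }

    /** Size in bytes of the loaded graph, exposed only so the sketch can be checked. */
    int graphSize() {
        return graphBytes.length;
    }
}
```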
-
train
-
getInvocationCount
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
- The number of train invocations.
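A small, self-contained sketch can illustrate the replicability argument. It assumes a trainer that derives one child generator per train call from a seeded java.util.SplittableRandom. The class CountedTrainer and its setInvocationCount method are hypothetical, and this page does not say whether TensorflowCheckpointTrainer draws random numbers at all. Because splitting a seeded SplittableRandom is deterministic, replaying n splits on a fresh instance puts it in the same state as an instance that has already been trained n times.

```java
import java.util.SplittableRandom;

/** Hypothetical trainer-like class showing why an invocation count enables replicability. */
final class CountedTrainer {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    CountedTrainer(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train(): takes one child RNG, bumps the count, returns a value drawn from it. */
    synchronized long train() {
        SplittableRandom local = rng.split();
        invocationCount++;
        return local.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Hypothetical: rewinds to the seed and replays n splits, as if train() had run n times. */
    synchronized void setInvocationCount(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("invocation count must be non-negative");
        }
        rng = new SplittableRandom(seed);
        for (int i = 0; i < n; i++) {
            rng.split();
        }
        invocationCount = n;
    }
}
```

Under these assumptions, storing only the seed and the invocation count is enough to reproduce the random stream of the next train call.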
-
toString
-
getProvenance