Class TensorFlowTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.interop.tensorflow.TensorFlowTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>
Trainer for TensorFlow. Expects the underlying TensorFlow graph to have named placeholders for
the inputs, ground truth outputs and a named output operation. The output operation should be
before any softmax or sigmoid non-linearities to allow the use of more optimized loss functions.
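To illustrate why the output operation should produce raw scores (logits), here is a minimal sketch in plain Java. It is not Tribuo or TensorFlow code: the class and method names are hypothetical, and it only shows that a loss computed directly from logits with the log-sum-exp trick stays finite where "softmax, then log" breaks down.

```java
// Hypothetical illustration, not part of Tribuo or TensorFlow.
final class LogitsLossSketch {

    /** Cross-entropy of softmax(logits) against a class index, computed stably via log-sum-exp. */
    static double crossEntropyFromLogits(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        for (double v : logits) {
            sum += Math.exp(v - max); // every exponent is <= 0, so this cannot overflow
        }
        double logSumExp = max + Math.log(sum);
        return logSumExp - logits[label];
    }

    /** Naive version: apply softmax first, then take the log. Overflows for large logits. */
    static double naiveCrossEntropy(double[] logits, int label) {
        double sum = 0.0;
        for (double v : logits) {
            sum += Math.exp(v);
        }
        return -Math.log(Math.exp(logits[label]) / sum);
    }

    public static void main(String[] args) {
        double[] large = {1000.0, 0.0};
        System.out.println("stable: " + crossEntropyFromLogits(large, 0));
        System.out.println("naive:  " + naiveCrossEntropy(large, 0));
    }
}
```

For moderate logits both methods agree. For large logits the naive version divides infinity by infinity, which is why fused loss functions expect the pre-softmax output.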
This trainer only works with graphs set up for minibatches. To recover single-example training, use a batch size of 1.
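The relationship between batch size and single-example training can be sketched in plain Java. The helper below is hypothetical and is not the trainer's code. It assumes that a final partial batch is kept rather than dropped, which this page does not confirm.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration of minibatch index ranges, not part of Tribuo.
final class MinibatchSketch {

    /** Number of minibatches needed to cover numExamples once (assumes the last partial batch is kept). */
    static int batchesPerEpoch(int numExamples, int batchSize) {
        if (numExamples < 0 || batchSize < 1) {
            throw new IllegalArgumentException("numExamples must be >= 0 and batchSize >= 1");
        }
        return (numExamples + batchSize - 1) / batchSize;
    }

    /** Half-open [start, end) index ranges for each minibatch in one epoch. */
    static List<int[]> batchRanges(int numExamples, int batchSize) {
        List<int[]> ranges = new ArrayList<>();
        int total = batchesPerEpoch(numExamples, batchSize); // also validates the arguments
        for (int i = 0; i < total; i++) {
            int start = i * batchSize;
            int end = Math.min(start + batchSize, numExamples);
            ranges.add(new int[]{start, end});
        }
        return ranges;
    }

    public static void main(String[] args) {
        for (int[] r : batchRanges(10, 4)) {
            System.out.println("[" + r[0] + ", " + r[1] + ")");
        }
    }
}
```

With a batch size of 1, every range holds exactly one example, so each gradient step sees a single example.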
This trainer stores the model parameters using either the serialisation functionality in TensorFlowUtil or a TensorFlow checkpoint.
N.B. TensorFlow support is experimental and may change without a major version bump.
-
Nested Class Summary
Nested Classes
Modifier and Type, Class - Description
static final class
static enum - The model format to emit.
-
Field Summary
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED
-
Constructor Summary
Constructors
TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) - Constructs a Trainer for a TensorFlow graph.
TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) - Constructs a Trainer for a TensorFlow graph.
TensorFlowTrainer(org.tensorflow.Graph graph, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) - Constructs a Trainer for a TensorFlow graph.
TensorFlowTrainer(org.tensorflow.Graph graph, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) - Constructs a Trainer for a TensorFlow graph.
TensorFlowTrainer(org.tensorflow.proto.framework.GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) - Constructs a Trainer for a TensorFlow graph.
TensorFlowTrainer(org.tensorflow.proto.framework.GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) - Constructs a Trainer for a TensorFlow graph.
-
Method Summary
Modifier and Type, Method - Description
int getInvocationCount() - The number of times this trainer instance has had its train method invoked.
void postConfig() - Used by the OLCUT configuration system, and should not be called by external code.
toString()
train(Dataset<T> examples) - Trains a predictive model using the examples in the given data set.
train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) - Trains a predictive model using the examples in the given data set.
-
Constructor Details
-
TensorFlowTrainer
public TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) throws IOException
Constructs a Trainer for a TensorFlow graph. Stores the model parameters inside the Tribuo model.
- Parameters:
graphPath - Path to the graph definition on disk. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
- Throws:
IOException - If the graph could not be loaded from the supplied path.
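The gradientParams argument is a Map<String, Float> keyed by optimiser parameter name. The following is a minimal sketch of building and checking such a map in plain Java. The key names "learningRate" and "momentum" are assumptions for illustration only; the keys actually required depend on the chosen GradientOptimiser, and the helper class is hypothetical.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical helper, not part of Tribuo.
final class GradientParamsSketch {

    /** Builds a parameter map for a plain gradient descent optimiser (key name is an assumption). */
    static Map<String, Float> sgdParams(float learningRate) {
        Map<String, Float> params = new HashMap<>();
        params.put("learningRate", learningRate);
        return params;
    }

    /** Returns the required keys that are absent from the supplied map. */
    static Set<String> missingKeys(Map<String, Float> params, Set<String> required) {
        Set<String> missing = new TreeSet<>(required);
        missing.removeAll(params.keySet());
        return missing;
    }

    public static void main(String[] args) {
        Map<String, Float> params = sgdParams(0.01f);
        Set<String> required = new TreeSet<>();
        required.add("learningRate");
        required.add("momentum");
        System.out.println("missing: " + missingKeys(params, required));
    }
}
```

Checking for missing keys before constructing the trainer gives a clearer error than a failure inside graph construction.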
-
TensorFlowTrainer
public TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) throws IOException
Constructs a Trainer for a TensorFlow graph. Stores the model parameters in a TensorFlow checkpoint.
- Parameters:
graphPath - Path to the graph definition on disk. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
checkpointPath - The path to save out the TensorFlow checkpoint.
- Throws:
IOException - If the graph could not be loaded from the supplied path.
-
TensorFlowTrainer
public TensorFlowTrainer(org.tensorflow.proto.framework.GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval)
Constructs a Trainer for a TensorFlow graph. Stores the model parameters inside the Tribuo model.
- Parameters:
graphDef - The graph definition. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
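The loggingInterval semantics described above (log every N steps, with -1 turning loss logging off) can be sketched as a small predicate. This is a hypothetical illustration in plain Java, not the trainer's actual logging code. Treating zero and other non-positive values as "never log" is an assumption made here to avoid a division by zero.

```java
// Hypothetical illustration of interval-based logging, not part of Tribuo.
final class LoggingIntervalSketch {

    /** True when the loss should be logged at this step; -1 (or any non-positive interval) disables logging. */
    static boolean shouldLog(int step, int loggingInterval) {
        return loggingInterval > 0 && step % loggingInterval == 0;
    }

    public static void main(String[] args) {
        for (int step = 1; step <= 300; step++) {
            if (shouldLog(step, 100)) {
                System.out.println("would log loss at step " + step);
            }
        }
    }
}
```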
-
TensorFlowTrainer
public TensorFlowTrainer(org.tensorflow.proto.framework.GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath)
Constructs a Trainer for a TensorFlow graph. Stores the model parameters in a TensorFlow checkpoint.
- Parameters:
graphDef - The graph definition. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
checkpointPath - The path to save out the TensorFlow checkpoint.
-
TensorFlowTrainer
public TensorFlowTrainer(org.tensorflow.Graph graph, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval)
Constructs a Trainer for a TensorFlow graph. Stores the model parameters inside the Tribuo model.
The graph can be closed after the trainer is constructed. Tribuo maintains a copy of the GraphDef inside the trainer.
- Parameters:
graph - The graph definition. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
-
TensorFlowTrainer
public TensorFlowTrainer(org.tensorflow.Graph graph, String outputName, GradientOptimiser optimiser, Map<String, Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath)
Constructs a Trainer for a TensorFlow graph. Stores the model parameters in a TensorFlow checkpoint.
The graph can be closed after the trainer is constructed. Tribuo maintains a copy of the GraphDef inside the trainer.
- Parameters:
graph - The graph definition. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
checkpointPath - The checkpoint path, if using checkpoints.
-
-
Method Details
-
postConfig
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Throws:
IOException
-
train
Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.
-
train
-
toString
-
getInvocationCount
public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
The number of train invocations.
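The link between the invocation count and RNG replicability can be sketched with java.util.SplittableRandom. This is a hypothetical stand-in and not this trainer's implementation. It assumes a trainer that holds a seeded RNG and splits off a fresh generator on each train call, so that a recorded seed plus invocation count is enough to rebuild the random stream used by a given call.

```java
import java.util.SplittableRandom;

// Hypothetical stand-in for a trainer with a seeded RNG, not part of Tribuo.
final class CountingTrainerSketch {
    private final SplittableRandom rng;
    private int invocationCount = 0;

    CountingTrainerSketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train(): splits a per-call RNG and draws one value from it. */
    synchronized long train() {
        SplittableRandom local = rng.split();
        invocationCount++;
        return local.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Rebuilds the first value drawn by the n-th (1-based) train call from the seed alone. */
    static long replay(long seed, int n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be >= 1");
        }
        SplittableRandom fresh = new SplittableRandom(seed);
        SplittableRandom local = fresh.split();
        for (int i = 1; i < n; i++) {
            local = fresh.split(); // advance the parent RNG once per earlier invocation
        }
        return local.nextLong();
    }

    public static void main(String[] args) {
        CountingTrainerSketch trainer = new CountingTrainerSketch(42L);
        trainer.train();
        trainer.train();
        long third = trainer.train();
        System.out.println(third == replay(42L, trainer.getInvocationCount()));
    }
}
```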
-
getProvenance
-