public final class TensorFlowTrainer<T extends Output<T>> extends Object implements Trainer<T>
This trainer only works with graphs set up for minibatches. To recover single-example training, use a batch size of 1.
This trainer stores the trained model using the serialisation functionality in TensorFlowUtil, or as a TF checkpoint.
N.B. TensorFlow support is experimental and may change without a major version bump.
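As a rough illustration of the minibatch requirement: with n examples and batch size b, one epoch presents about ceil(n / b) batches to the graph, and b = 1 gives one example per step. The helper below is hypothetical (not part of Tribuo) and assumes contiguous batches with a smaller final batch; how TensorFlowTrainer actually orders examples or handles a partial last batch is not described on this page.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: splits example indices [0, numExamples) into contiguous minibatches.
// Illustrative only; not part of Tribuo's API.
final class MinibatchSketch {
    private MinibatchSketch() {}

    /** Returns {start, endExclusive} pairs; the final batch may be smaller than batchSize. */
    static List<int[]> partition(int numExamples, int batchSize) {
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be >= 1, found " + batchSize);
        }
        List<int[]> batches = new ArrayList<>();
        for (int start = 0; start < numExamples; start += batchSize) {
            int end = Math.min(start + batchSize, numExamples);
            batches.add(new int[] {start, end});
        }
        return batches;
    }

    public static void main(String[] args) {
        // 10 examples with batch size 4 -> [0, 4), [4, 8), [8, 10)
        for (int[] b : partition(10, 4)) {
            System.out.println("[" + b[0] + ", " + b[1] + ")");
        }
        // Batch size 1 recovers single-example training: one batch per example.
        System.out.println(partition(10, 1).size()); // prints 10
    }
}
```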
Modifier and Type | Class and Description |
---|---|
static class | TensorFlowTrainer.TensorFlowTrainerProvenance |
static class | TensorFlowTrainer.TFModelFormat The model format to emit. |
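The class description says the trained model is stored either through TensorFlowUtil serialisation or as a TF checkpoint, and the overloads that take a checkpointPath describe it as "the path to save out the TensorFlow checkpoint". A minimal sketch of that choice follows, assuming the format is implied by whether a checkpoint path was supplied. The enum and its constant names are hypothetical stand-ins, not the actual TFModelFormat values.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical stand-in for a model-format enum; the constant names are assumptions.
enum ModelFormatSketch {
    /** Weights serialised into the model object (cf. TensorFlowUtil serialisation). */
    SERIALISED,
    /** Weights written to a TensorFlow checkpoint on disk. */
    CHECKPOINT;

    /** Assumption: supplying a checkpoint path selects the checkpoint format. */
    static ModelFormatSketch forCheckpointPath(Path checkpointPath) {
        return checkpointPath == null ? SERIALISED : CHECKPOINT;
    }
}

final class ModelFormatDemo {
    public static void main(String[] args) {
        System.out.println(ModelFormatSketch.forCheckpointPath(null));                   // SERIALISED
        System.out.println(ModelFormatSketch.forCheckpointPath(Paths.get("/tmp/ckpt"))); // CHECKPOINT
    }
}
```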
Fields inherited from interface Trainer: DEFAULT_SEED
Constructor and Description |
---|
TensorFlowTrainer(org.tensorflow.proto.framework.GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) Constructs a Trainer for a TensorFlow graph. |
TensorFlowTrainer(org.tensorflow.proto.framework.GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) Constructs a Trainer for a TensorFlow graph. |
TensorFlowTrainer(org.tensorflow.Graph graph, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) Constructs a Trainer for a TensorFlow graph. |
TensorFlowTrainer(org.tensorflow.Graph graph, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) Constructs a Trainer for a TensorFlow graph. |
TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) Constructs a Trainer for a TensorFlow graph. |
TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) Constructs a Trainer for a TensorFlow graph. |
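The six constructors cover three ways of supplying the graph (a Path on disk, a GraphDef protobuf, or a live Graph), each with and without a checkpointPath. Only the Path overloads declare IOException, since they are the ones that read from disk. The sketch below shows one common Java pattern for such overloads: each source is normalised to serialised graph bytes and the shorter overload delegates to the longer one with a null checkpoint path. It is a hypothetical illustration that uses byte[] in place of the TensorFlow types and omits the other parameters; it is not Tribuo's implementation.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

// Hypothetical sketch of telescoping constructors; byte[] stands in for a serialised GraphDef.
final class OverloadSketch {
    private final byte[] graphBytes;   // private copy of the serialised graph
    private final Path checkpointPath; // null when checkpoints are not used

    /** Canonical constructor: every other overload ends up here. */
    OverloadSketch(byte[] graphBytes, Path checkpointPath) {
        this.graphBytes = Arrays.copyOf(graphBytes, graphBytes.length);
        this.checkpointPath = checkpointPath;
    }

    /** No checkpoint path: delegate with null. */
    OverloadSketch(byte[] graphBytes) {
        this(graphBytes, null);
    }

    /** Reading from disk is the only step that can fail with an IOException. */
    OverloadSketch(Path graphPath, Path checkpointPath) throws IOException {
        this(Files.readAllBytes(graphPath), checkpointPath);
    }

    OverloadSketch(Path graphPath) throws IOException {
        this(graphPath, null);
    }

    int graphSize() { return graphBytes.length; }

    boolean usesCheckpoint() { return checkpointPath != null; }
}
```

Keeping the copy and field assignment in a single canonical constructor means the defensive copy is made exactly once, whichever overload the caller used.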
Modifier and Type | Method and Description |
---|---|
int | getInvocationCount() The number of times this trainer instance has had its train method invoked. |
TrainerProvenance | getProvenance() |
void | postConfig() Used by the OLCUT configuration system, and should not be called by external code. |
String | toString() |
TensorFlowModel<T> | train(Dataset<T> examples) Trains a predictive model using the examples in the given data set. |
TensorFlowModel<T> | train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) Trains a predictive model using the examples in the given data set. |
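The trainBatchSize, epochs and loggingInterval constructor parameters describe the shape of a train call: a number of SGD epochs, each made of minibatch steps, with the loss reported periodically unless loggingInterval is -1. The skeleton below is a hypothetical sketch of that control flow only. The stepLoss function is a placeholder for running the graph's training operation on one minibatch, and logging on every loggingInterval-th step is an assumption about how the interval is counted.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntToDoubleFunction;

// Hypothetical training-loop skeleton; no TensorFlow calls are made.
final class TrainLoopSketch {
    private TrainLoopSketch() {}

    /**
     * Runs epochs x ceil(numExamples / batchSize) steps and returns the loss messages that would be logged.
     * stepLoss is a placeholder for executing the training op on one minibatch.
     */
    static List<String> run(int numExamples, int batchSize, int epochs, int loggingInterval,
                            IntToDoubleFunction stepLoss) {
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be >= 1, found " + batchSize);
        }
        List<String> log = new ArrayList<>();
        int step = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int start = 0; start < numExamples; start += batchSize) {
                double loss = stepLoss.applyAsDouble(step);
                step++;
                // -1 quiesces loss logging (per the docs); treating any non-positive value
                // the same way is an assumption of this sketch.
                if (loggingInterval > 0 && step % loggingInterval == 0) {
                    log.add("epoch " + epoch + ", step " + step + ", loss " + loss);
                }
            }
        }
        return log;
    }
}
```

With 10 examples, a batch size of 4 and 2 epochs, this runs 6 steps; a logging interval of 2 therefore produces 3 messages, and -1 produces none.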
public TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval) throws IOException
Parameters:
graphPath - Path to the graph definition on disk. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
Throws:
IOException - If the graph could not be loaded from the supplied path.
public TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath) throws IOException
Parameters:
graphPath - Path to the graph definition on disk. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
checkpointPath - The path to save out the TensorFlow checkpoint.
Throws:
IOException - If the graph could not be loaded from the supplied path.
public TensorFlowTrainer(org.tensorflow.proto.framework.GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval)
Parameters:
graphDef - The graph definition. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
public TensorFlowTrainer(org.tensorflow.proto.framework.GraphDef graphDef, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath)
Parameters:
graphDef - The graph definition. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
checkpointPath - The path to save out the TensorFlow checkpoint.
public TensorFlowTrainer(org.tensorflow.Graph graph, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval)
The graph can be closed after the trainer is constructed. Tribuo maintains a copy of the GraphDef inside the trainer.
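Closing the Graph after construction is safe because the trainer keeps its own copy of the graph definition rather than a reference to the live graph. The sketch below shows that defensive-copy idea with a hypothetical AutoCloseable handle standing in for org.tensorflow.Graph; the class names and the byte-array representation are assumptions made for illustration.

```java
import java.util.Arrays;

// Hypothetical stand-in for a native graph handle that becomes unusable once closed.
final class GraphHandleSketch implements AutoCloseable {
    private byte[] serialised;

    GraphHandleSketch(byte[] serialised) { this.serialised = serialised; }

    /** Returns a fresh copy of the serialised graph, or fails if the handle is closed. */
    byte[] toBytes() {
        if (serialised == null) {
            throw new IllegalStateException("graph is closed");
        }
        return Arrays.copyOf(serialised, serialised.length);
    }

    @Override
    public void close() { serialised = null; }
}

// Hypothetical holder that snapshots the graph at construction time.
final class CopyingTrainerSketch {
    private final byte[] graphDef;

    CopyingTrainerSketch(GraphHandleSketch graph) {
        // The copy is taken eagerly, so the handle may be closed afterwards.
        this.graphDef = graph.toBytes();
    }

    int graphDefSize() { return graphDef.length; }
}
```

In use, the trainer can be built inside a try-with-resources block that closes the handle, and it still holds the 3-byte snapshot afterwards.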
Parameters:
graph - The graph definition. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
public TensorFlowTrainer(org.tensorflow.Graph graph, String outputName, GradientOptimiser optimiser, Map<String,Float> gradientParams, FeatureConverter featureConverter, OutputConverter<T> outputConverter, int trainBatchSize, int epochs, int testBatchSize, int loggingInterval, Path checkpointPath)
The graph can be closed after the trainer is constructed. Tribuo maintains a copy of the GraphDef inside the trainer.
Parameters:
graph - The graph definition. Must have the necessary targets and placeholders.
outputName - The name of the output operation.
optimiser - The gradient optimiser.
gradientParams - The parameters of the gradient optimiser.
featureConverter - The example converter to convert a Tribuo Example into a Tensor.
outputConverter - The output converter to convert a Tribuo Output into a Tensor and back. This encodes the output type.
trainBatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
loggingInterval - The logging interval. Set to -1 to quiesce the loss level logging.
checkpointPath - The checkpoint path, if using checkpoints.
public void postConfig() throws IOException
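Used by the OLCUT configuration system, and should not be called by external code.
Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
Throws:
IOException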
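postConfig runs after OLCUT has injected the configured fields, and its signature declares IOException. A plausible reading, consistent with the Path-based constructors also throwing IOException, is that it loads the graph from disk and checks the remaining settings. The sketch below shows that post-configuration pattern without OLCUT; the field names and the accepted parameter name "learningRate" are hypothetical, and this is not Tribuo's postConfig.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Hypothetical post-configuration hook: fields are assumed to have been set beforehand
// (by a configuration system in the real library, or by hand here).
final class PostConfigSketch {
    Path graphPath;                                        // hypothetical configured field
    Map<String, Float> gradientParams = new HashMap<>();   // hypothetical configured field
    int trainBatchSize = 1;                                // hypothetical configured field
    Set<String> allowedParamNames = Collections.singleton("learningRate"); // assumed optimiser parameter names
    byte[] graphBytes;                                     // populated by postConfig

    void postConfig() throws IOException {
        if (trainBatchSize < 1) {
            throw new IllegalArgumentException("trainBatchSize must be >= 1, found " + trainBatchSize);
        }
        if (!allowedParamNames.containsAll(gradientParams.keySet())) {
            throw new IllegalArgumentException("Gradient parameters " + gradientParams.keySet()
                    + " are not all in " + allowedParamNames);
        }
        // Loading the graph is the step that can throw IOException.
        graphBytes = Files.readAllBytes(graphPath);
    }
}
```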
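public TensorFlowModel<T> train(Dataset<T> examples)
Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.
public TensorFlowModel<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.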
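public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>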
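The invocation count makes a trainer's random number stream replayable: if each train call consumes the RNG in a fixed way, knowing how many calls happened is enough to put a fresh trainer into the same state. The sketch below illustrates this with java.util.SplittableRandom, splitting off one child generator per call. The class and its setInvocationCount method are hypothetical, and this page does not say how, or whether, TensorFlowTrainer itself draws from such an RNG.

```java
import java.util.SplittableRandom;

// Hypothetical trainer skeleton showing how an invocation counter makes RNG usage replayable.
final class InvocationCountSketch {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    InvocationCountSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train(...): takes one child RNG per call and returns a draw from it. */
    synchronized long train() {
        SplittableRandom localRng = rng.split();
        invocationCount++;
        return localRng.nextLong();
    }

    synchronized int getInvocationCount() { return invocationCount; }

    /** Rebuilds the RNG from the seed and fast-forwards it past {@code count} earlier invocations. */
    synchronized void setInvocationCount(int count) {
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split();
        }
        invocationCount = count;
    }
}
```

A second instance built with the same seed and fast-forwarded to the first instance's count produces the same next draw, which is the replicability property the count is meant to support.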
public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.