Class TensorflowTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.interop.tensorflow.TensorflowTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>
Trainer for Tensorflow. Expects the underlying Tensorflow graph to have specific placeholders and
targets listed below.
- TensorflowModel.INPUT_NAME - the input minibatch.
- TensorflowModel.OUTPUT_NAME - the predicted output.
- TARGET - the output to predict.
- TRAIN - the train function to run (usually a single step of SGD).
- TRAINING_LOSS - the loss tensor to extract for logging.
- EPOCH - the current epoch number, used for gradient scaling.
- IS_TRAINING - a boolean placeholder to turn on dropout or other training-specific functionality.
- INIT - the function to initialise the graph.
This trainer uses the serialisation functionality in TensorflowUtil, as opposed to a SavedModel or a checkpoint.
N.B. Tensorflow support is experimental and may change without a major version bump.
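Because the trainer depends on the placeholders and targets listed above, it can help to check a graph against that contract before training. The following is a minimal sketch in plain Java with no Tribuo or TensorFlow dependency: the graph is modelled as a set of operation names, and the string values in `REQUIRED` are hypothetical stand-ins, since this page does not show the actual values of the constants. `GraphContractCheck` and `missingOps` are illustrative names, not part of Tribuo.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

final class GraphContractCheck {
    // Hypothetical names standing in for TensorflowModel.INPUT_NAME, TensorflowModel.OUTPUT_NAME,
    // TARGET, TRAIN, TRAINING_LOSS, EPOCH, IS_TRAINING and INIT; the real constant values may differ.
    static final List<String> REQUIRED = List.of(
            "input", "output", "target", "train", "training_loss", "epoch", "is_training", "init");

    /** Returns the required operation names absent from the graph, in declaration order. */
    static Set<String> missingOps(Set<String> graphOpNames) {
        Set<String> missing = new LinkedHashSet<>();
        for (String name : REQUIRED) {
            if (!graphOpNames.contains(name)) {
                missing.add(name);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        // A toy graph that forgot to define the dropout switch and the initialiser.
        Set<String> ops = Set.of("input", "output", "target", "train", "training_loss", "epoch");
        System.out.println("Missing: " + missingOps(ops)); // prints Missing: [is_training, init]
    }
}
```

Reporting every missing name at once, rather than failing on the first, makes a mis-exported graph quicker to fix.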
-
Nested Class Summary
Nested Classes
Modifier and Type | Class | Description
static final class
-
Field Summary
Fields
Modifier and Type | Field
static final String | EPOCH
static final String | INIT
static final String | IS_TRAINING
static final String | TARGET
static final String | TRAIN
static final String | TRAINING_LOSS
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED
-
Constructor Summary
Constructors
Constructor | Description
TensorflowTrainer(byte[] graphDef, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize) | Constructs a Trainer for a Tensorflow graph.
TensorflowTrainer(Path graphPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize) | Constructs a Trainer for a Tensorflow graph.
-
Method Summary
Modifier and Type | Method | Description
int | getInvocationCount() | The number of times this trainer instance has had its train method invoked.
void | postConfig() | Used by the OLCUT configuration system, and should not be called by external code.
String | toString() |
| train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) | Trains a predictive model using the examples in the given data set.
-
Field Details
-
TARGET
-
TRAIN
-
TRAINING_LOSS
-
EPOCH
-
IS_TRAINING
-
INIT
-
-
Constructor Details
-
TensorflowTrainer
public TensorflowTrainer(Path graphPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize) throws IOException
Constructs a Trainer for a Tensorflow graph.
- Parameters:
graphPath - The path to the graph protobuf. Must have the targets and placeholders specified above.
exampleTransformer - The example transformer to convert a Tribuo Example into a Tensor.
outputTransformer - The output transformer to convert a Tribuo Output into a Tensor and back. This encodes the output type.
minibatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
- Throws:
IOException - If the graphPath is invalid or the graph fails to load.
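The two constructors take the same graph either as a file path or as raw bytes. One plausible way to structure such a pair is for the `Path` variant to read the protobuf bytes and delegate to the `byte[]` variant, so that a missing or unreadable file surfaces as the `IOException` documented above. The sketch below uses a hypothetical `GraphHolder` class, not Tribuo's code; the non-empty check is an added assumption.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

final class GraphHolder {
    private final byte[] graphDef;

    /** Hypothetical byte[] constructor: keeps a defensive copy of the serialised graph. */
    GraphHolder(byte[] graphDef) {
        // Assumption for this sketch: an empty graph definition is rejected up front.
        if (graphDef == null || graphDef.length == 0) {
            throw new IllegalArgumentException("graphDef must be non-empty");
        }
        this.graphDef = Arrays.copyOf(graphDef, graphDef.length);
    }

    /** Hypothetical Path constructor: reads the file, then delegates to the byte[] constructor. */
    GraphHolder(Path graphPath) throws IOException {
        // Files.readAllBytes throws IOException (e.g. NoSuchFileException) if the file cannot be read.
        this(Files.readAllBytes(graphPath));
    }

    int size() {
        return graphDef.length;
    }
}
```

Delegating keeps all validation in one constructor, so both entry points enforce the same rules.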
-
TensorflowTrainer
public TensorflowTrainer(byte[] graphDef, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize)
Constructs a Trainer for a Tensorflow graph.
- Parameters:
graphDef - The graph definition as a byte array. Must have the targets and placeholders specified above.
exampleTransformer - The example transformer to convert a Tribuo Example into a Tensor.
outputTransformer - The output transformer to convert a Tribuo Output into a Tensor and back. This encodes the output type.
minibatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
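To make the roles of `minibatchSize`, `epochs` and `testBatchSize` concrete, here is a small arithmetic sketch. It assumes the conventional scheme of one TRAIN step per minibatch, with a final partial batch when the dataset size is not a multiple of the batch size; this page does not say how TensorflowTrainer handles the remainder, so treat that as an assumption. `BatchSchedule` is a hypothetical helper.

```java
final class BatchSchedule {
    /** Number of minibatches needed to cover n examples, counting a final partial batch. */
    static int batchesPerPass(int n, int batchSize) {
        if (n < 0 || batchSize <= 0) {
            throw new IllegalArgumentException("need n >= 0 and batchSize > 0");
        }
        // Ceiling division, done in long arithmetic to avoid int overflow for large n.
        return (int) (((long) n + batchSize - 1) / batchSize);
    }

    /** Total training steps, assuming one TRAIN step per minibatch in every epoch. */
    static long trainingSteps(int n, int minibatchSize, int epochs) {
        return (long) epochs * batchesPerPass(n, minibatchSize);
    }

    public static void main(String[] args) {
        // 1,000 training examples, minibatchSize = 32, epochs = 5: 32 batches per epoch.
        System.out.println(trainingSteps(1000, 32, 5)); // prints 160
        // Scoring 250 test examples with testBatchSize = 100 takes batches of 100, 100 and 50.
        System.out.println(batchesPerPass(250, 100));   // prints 3
    }
}
```

A separate `testBatchSize` lets inference use larger batches than training, since no gradients need to be held in memory at test time.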
-
-
Method Details
-
postConfig
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Throws:
IOException
-
train
-
toString
-
getInvocationCount
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
- The number of train invocations.
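The invocation count matters because a trainer's random number generator advances each time train is called. The sketch below is a generic illustration, not Tribuo's implementation: a hypothetical `SeededTrainer` derives a fresh RNG per call from a fixed seed using `java.util.SplittableRandom`, and knowing the count lets a second instance skip ahead and reproduce a later call's stream. The `fastForward` method is invented for this example.

```java
import java.util.SplittableRandom;

final class SeededTrainer {
    private final SplittableRandom rng;
    private int invocationCount = 0;

    SeededTrainer(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train(): splits off a per-call RNG (e.g. for shuffling) and returns its first draw. */
    synchronized long train() {
        SplittableRandom local = rng.split(); // deterministically advances the parent generator
        invocationCount++;
        return local.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Hypothetical helper: replays the parent RNG state a trainer would have after n calls. */
    synchronized void fastForward(int n) {
        for (int i = 0; i < n; i++) {
            rng.split();
            invocationCount++;
        }
    }

    public static void main(String[] args) {
        SeededTrainer original = new SeededTrainer(12345L);
        original.train();
        original.train();
        long third = original.train();

        // Same seed, skipped ahead past the first two calls, reproduces the third call's draw.
        SeededTrainer replay = new SeededTrainer(12345L);
        replay.fastForward(2);
        System.out.println(third == replay.train()); // prints true
    }
}
```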
-
getProvenance
-