public final class TensorflowTrainer<T extends Output<T>> extends Object implements Trainer<T>
Expects the underlying Tensorflow graph to have the following named placeholders and targets:
TensorflowModel.INPUT_NAME - the input minibatch.
TensorflowModel.OUTPUT_NAME - the predicted output.
TARGET - the output to predict.
TRAIN - the train function to run (usually a single step of SGD).
TRAINING_LOSS - the loss tensor to extract for logging.
EPOCH - the current epoch number, used for gradient scaling.
IS_TRAINING - a boolean placeholder to turn on dropout or other training-specific functionality.
INIT - the function to initialise the graph.
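Before constructing the trainer, it can help to confirm that a graph actually defines every one of these names. The sketch below is a hypothetical helper (not part of Tribuo) that compares a set of node names, obtained from the graph by whatever means your TensorFlow binding provides, against the required list. The string values in `main` are placeholders only; substitute the real values of the constants listed above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hypothetical helper, not a Tribuo API: reports which required node names
// are absent from a graph. The node-name set is passed in directly here.
final class GraphNameCheck {
    private GraphNameCheck() {}

    /** Returns the required names that are missing, in the order they were given. */
    static List<String> missingNames(Set<String> graphNodeNames, List<String> requiredNames) {
        List<String> missing = new ArrayList<>();
        for (String name : requiredNames) {
            if (!graphNodeNames.contains(name)) {
                missing.add(name);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        // Placeholder strings (assumed): use the actual values of
        // TensorflowModel.INPUT_NAME, TARGET, TRAIN, etc. from the library.
        List<String> required = List.of("input", "output", "target", "train",
                "training_loss", "epoch", "is_training", "init");
        Set<String> nodes = Set.of("input", "output", "target", "train", "init");
        System.out.println(missingNames(nodes, required)); // prints [training_loss, epoch, is_training]
    }
}
```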
This trainer uses the serialisation functionality in TensorflowUtil, as opposed to a SavedModel or a checkpoint.
N.B. Tensorflow support is experimental and may change without a major version bump.
Modifier and Type | Class and Description |
---|---|
static class | TensorflowTrainer.TensorflowTrainerProvenance |
Modifier and Type | Field and Description |
---|---|
static String | EPOCH |
static String | INIT |
static String | IS_TRAINING |
static String | TARGET |
static String | TRAIN |
static String | TRAINING_LOSS |
Fields inherited from interface org.tribuo.Trainer: DEFAULT_SEED
Constructor | Description |
---|---|
TensorflowTrainer(byte[] graphDef, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize) | Constructs a Trainer for a tensorflow graph. |
TensorflowTrainer(Path graphPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize) | Constructs a Trainer for a tensorflow graph. |
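The two constructors take the same arguments except for how the graph is supplied: as a file path or as the serialised graph definition in a byte array. If you hold a graph file but want the byte[] form, the following is a minimal sketch of loading it, assuming the file contains a serialised GraphDef protobuf. `GraphDefLoader` is a hypothetical name, and this is not necessarily how the Path constructor is implemented.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper: reads a serialised graph definition from disk as raw bytes,
// suitable for passing as the graphDef argument. Performs only basic sanity checks.
final class GraphDefLoader {
    private GraphDefLoader() {}

    static byte[] readGraphDef(Path graphPath) throws IOException {
        if (!Files.isRegularFile(graphPath)) {
            throw new IOException("Graph file not found or not a regular file: " + graphPath);
        }
        byte[] bytes = Files.readAllBytes(graphPath);
        if (bytes.length == 0) {
            throw new IOException("Graph file is empty: " + graphPath);
        }
        return bytes;
    }
}
```

Failing early with an IOException mirrors the documented behaviour of the Path constructor, which also throws IOException when the graph cannot be loaded.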
Modifier and Type | Method | Description |
---|---|---|
int | getInvocationCount() | The number of times this trainer instance has had its train method invoked. |
TrainerProvenance | getProvenance() | |
void | postConfig() | Used by the OLCUT configuration system, and should not be called by external code. |
String | toString() | |
Model<T> | train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) | Trains a predictive model using the examples in the given data set. |
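The minibatchSize and epochs constructor arguments together determine how many training steps train performs. The sketch below is a generic epoch/minibatch schedule, not Tribuo's internal loop: it does not shuffle, and it keeps a final partial batch, which the real trainer may handle differently. With n examples it yields ceil(n / minibatchSize) batches per epoch.

```java
import java.util.ArrayList;
import java.util.List;

// Generic sketch of the batch schedule an epoch-based SGD loop iterates over.
// Assumptions: no shuffling, and the last batch of an epoch may be smaller.
final class MinibatchSchedule {
    private MinibatchSchedule() {}

    /** Splits example indices [0, numExamples) into consecutive batches. */
    static List<int[]> batches(int numExamples, int minibatchSize) {
        if (minibatchSize <= 0) {
            throw new IllegalArgumentException("minibatchSize must be positive");
        }
        List<int[]> out = new ArrayList<>();
        for (int start = 0; start < numExamples; start += minibatchSize) {
            int end = Math.min(start + minibatchSize, numExamples);
            int[] batch = new int[end - start];
            for (int i = start; i < end; i++) {
                batch[i - start] = i;
            }
            out.add(batch);
        }
        return out;
    }

    /** Total optimiser steps over all epochs: epochs * ceil(numExamples / minibatchSize). */
    static long totalSteps(int numExamples, int minibatchSize, int epochs) {
        return (long) epochs * batches(numExamples, minibatchSize).size();
    }

    public static void main(String[] args) {
        int epochs = 3;
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int[] batch : batches(10, 4)) {
                // A graph-driven loop (assumed) would feed this batch, the epoch
                // number and is-training = true here, then run the train target once.
                System.out.println("epoch " + epoch + ", batch of " + batch.length);
            }
        }
        System.out.println(totalSteps(10, 4, 3)); // prints 9
    }
}
```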
public static final String TARGET
public static final String TRAIN
public static final String TRAINING_LOSS
public static final String EPOCH
public static final String IS_TRAINING
public static final String INIT
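IS_TRAINING names a boolean placeholder that the graph can use to enable training-only behaviour such as dropout. How the flag is consumed is entirely up to the graph. As an illustration only, here is standard inverted dropout written in plain Java and gated on such a flag; `DropoutSketch` is hypothetical and unrelated to any TensorFlow op.

```java
import java.util.Random;

// Generic inverted dropout gated by an is-training flag. Illustrates what an
// IS_TRAINING placeholder typically switches; the real behaviour lives in the graph.
final class DropoutSketch {
    private DropoutSketch() {}

    static float[] dropout(float[] x, double dropRate, boolean isTraining, Random rng) {
        if (dropRate < 0.0 || dropRate >= 1.0) {
            throw new IllegalArgumentException("dropRate must be in [0, 1)");
        }
        float[] out = x.clone();
        if (!isTraining || dropRate == 0.0) {
            return out; // inference: identity, no rescaling needed
        }
        float scale = (float) (1.0 / (1.0 - dropRate));
        for (int i = 0; i < out.length; i++) {
            // Keep each unit with probability (1 - dropRate) and rescale the survivors.
            out[i] = rng.nextDouble() < dropRate ? 0.0f : out[i] * scale;
        }
        return out;
    }
}
```

Rescaling the surviving units during training keeps their expected value unchanged, so the same weights can be used at inference with the flag switched off.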
public TensorflowTrainer(Path graphPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize) throws IOException
Constructs a Trainer for a tensorflow graph.
Parameters:
graphPath - The path to the graph protobuf. Must have the targets and placeholders specified above.
exampleTransformer - The example transformer to convert a Tribuo Example into a Tensor.
outputTransformer - The output transformer to convert a Tribuo Output into a Tensor and back. This encodes the output type.
minibatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
Throws:
IOException - If the graphPath is invalid or the graph fails to load.
public TensorflowTrainer(byte[] graphDef, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs, int testBatchSize)
Constructs a Trainer for a tensorflow graph.
Parameters:
graphDef - The graph definition as a byte array. Must have the targets and placeholders specified above.
exampleTransformer - The example transformer to convert a Tribuo Example into a Tensor.
outputTransformer - The output transformer to convert a Tribuo Output into a Tensor and back. This encodes the output type.
minibatchSize - The minibatch size to use in training.
epochs - The number of SGD epochs to run.
testBatchSize - The minibatch size to use at test time.
public void postConfig() throws IOException
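Used by the OLCUT configuration system, and should not be called by external code.
Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
Throws:
IOException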
public Model<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
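Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.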
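Training requires converting each Tribuo Example and its Output into tensors, via the ExampleTransformer and OutputTransformer passed to the constructor. The hypothetical helpers below show the kind of mapping involved for a classification task: sparse named features to a dense vector, a label to a one-hot target, and a score vector back to a label by argmax. They are not the Tribuo transformer interfaces, whose methods are not described on this page.

```java
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins, NOT Tribuo's ExampleTransformer/OutputTransformer APIs:
// features -> dense float vector, label -> one-hot vector, scores -> label.
final class TensorConversionSketch {
    private TensorConversionSketch() {}

    /**
     * Dense feature vector. Assumes featureIndex maps names to 0..size-1;
     * features not in the index are ignored, absent features stay 0.
     */
    static float[] toDense(Map<String, Double> features, Map<String, Integer> featureIndex) {
        float[] dense = new float[featureIndex.size()];
        for (Map.Entry<String, Double> e : features.entrySet()) {
            Integer idx = featureIndex.get(e.getKey());
            if (idx != null) {
                dense[idx] = e.getValue().floatValue();
            }
        }
        return dense;
    }

    /** One-hot target vector for a label, given a fixed label ordering. */
    static float[] oneHot(String label, List<String> labels) {
        int idx = labels.indexOf(label);
        if (idx < 0) {
            throw new IllegalArgumentException("Unknown label: " + label);
        }
        float[] target = new float[labels.size()];
        target[idx] = 1.0f;
        return target;
    }

    /** Decodes a score vector back to a label by taking the argmax. */
    static String decode(float[] scores, List<String> labels) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return labels.get(best);
    }
}
```

The round trip (oneHot for targets, decode for predictions) is why the output transformer is described as converting to a Tensor "and back": it fixes the output type on both sides of the graph.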
public int getInvocationCount()
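Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>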
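To see why an invocation count supports replicability, consider a trainer that derives a fresh RNG for each train call from one seeded parent stream. The skeleton below is hypothetical: it is not Tribuo code, and this page does not say whether TensorflowTrainer itself draws random numbers. It uses java.util.SplittableRandom, so rewinding to the base seed and skipping `count` earlier splits reproduces the next call's stream exactly.

```java
import java.util.SplittableRandom;

// Hypothetical trainer skeleton: each train() call splits a per-call RNG off a
// seeded parent stream, so the invocation count identifies the stream position.
final class CountingTrainerSketch {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    CountingTrainerSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train(): returns the first value drawn from this call's RNG. */
    synchronized long train() {
        SplittableRandom localRng = rng.split();
        invocationCount++;
        return localRng.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Rewinds to the base seed, then skips past `count` earlier invocations. */
    synchronized void setInvocationCount(int count) {
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split();
        }
        invocationCount = count;
    }
}
```

For example, a fresh instance with the same seed, fast-forwarded with setInvocationCount(2), produces the same value from its next train() call as the third call on the original instance.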
public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.