public final class TensorflowCheckpointTrainer<T extends Output<T>> extends Object implements Trainer<T>
The supplied graph is expected to contain the following named placeholders and operations:

- TensorflowModel.INPUT_NAME - the input minibatch.
- TensorflowModel.OUTPUT_NAME - the predicted output.
- TensorflowTrainer.TARGET - the output to predict.
- TensorflowTrainer.TRAIN - the train function to run (usually a single step of SGD).
- TensorflowTrainer.TRAINING_LOSS - the loss tensor to extract for logging.
- TensorflowTrainer.EPOCH - the current epoch number, used for gradient scaling.
- TensorflowTrainer.IS_TRAINING - a boolean placeholder to turn on dropout or other training-specific functionality.
- TensorflowTrainer.INIT - the function to initialise the graph.
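The sketch below makes this contract concrete. It is a pre-flight check that reports which required names a graph does not expose. It assumes the constants resolve to the literal strings shown (`input`, `output`, `target`, and so on). Those values and the `GraphNameCheck` class are hypothetical and do not come from this page, so check the real constants in `TensorflowModel` and `TensorflowTrainer`. The set of operation names stands in for whatever you read from the graph definition.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/** Hypothetical helper: reports which required graph names are missing. */
public final class GraphNameCheck {
    // ASSUMPTION: illustrative values for the TensorflowModel / TensorflowTrainer constants.
    static final List<String> REQUIRED = List.of(
            "input",          // TensorflowModel.INPUT_NAME: input minibatch
            "output",         // TensorflowModel.OUTPUT_NAME: predicted output
            "target",         // TensorflowTrainer.TARGET: output to predict
            "train",          // TensorflowTrainer.TRAIN: one optimisation step
            "training_loss",  // TensorflowTrainer.TRAINING_LOSS: loss for logging
            "epoch",          // TensorflowTrainer.EPOCH: current epoch number
            "is_training",    // TensorflowTrainer.IS_TRAINING: dropout switch
            "init");          // TensorflowTrainer.INIT: graph initialiser

    private GraphNameCheck() {}

    /** Returns the required names absent from opNames, in declaration order. */
    public static List<String> missing(Set<String> opNames) {
        List<String> absent = new ArrayList<>();
        for (String name : REQUIRED) {
            if (!opNames.contains(name)) {
                absent.add(name);
            }
        }
        return absent;
    }

    public static void main(String[] args) {
        Set<String> ops = Set.of("input", "output", "target", "train", "init");
        System.out.println(missing(ops)); // [training_loss, epoch, is_training]
    }
}
```

Running a check like this before constructing the trainer turns a missing node into a readable list of names, rather than a failure partway through training.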
This trainer uses the native Tensorflow serialisation functionality and saves to a checkpoint on disk. It is much more fragile than the TensorflowTrainer.
N.B. Tensorflow support is experimental and may change without a major version bump.
Modifier and Type | Class and Description |
---|---|
static class | TensorflowCheckpointTrainer.TensorflowCheckpointTrainerProvenance |
Modifier and Type | Field and Description |
---|---|
static String | MODEL_FILENAME |

Fields inherited from interface Trainer: DEFAULT_SEED
Constructor and Description |
---|
TensorflowCheckpointTrainer(Path graphPath, Path checkpointRootPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs) Builds a trainer using the supplied graph and arguments. |
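Together, `minibatchSize` and `epochs` determine how many times the graph's train operation runs. The sketch below computes that count. It assumes two things this page does not confirm: each epoch makes one pass over the dataset, and a final, smaller minibatch is still fed rather than dropped. `TrainingSchedule` is a hypothetical name.

```java
/** Hypothetical helper: counts train-op invocations for a given configuration. */
public final class TrainingSchedule {
    private TrainingSchedule() {}

    /** Minibatches per epoch, keeping a final partial batch (ASSUMPTION). */
    public static int batchesPerEpoch(int numExamples, int minibatchSize) {
        if (numExamples < 0 || minibatchSize <= 0) {
            throw new IllegalArgumentException("numExamples must be >= 0 and minibatchSize > 0");
        }
        // Ceiling division without floating point.
        return (numExamples + minibatchSize - 1) / minibatchSize;
    }

    /** Total train steps across all epochs. */
    public static long totalSteps(int numExamples, int minibatchSize, int epochs) {
        if (epochs < 0) {
            throw new IllegalArgumentException("epochs must be >= 0");
        }
        return (long) batchesPerEpoch(numExamples, minibatchSize) * epochs;
    }

    public static void main(String[] args) {
        // 1000 examples in batches of 32: 31 full batches plus one of 8, so 32 per epoch.
        System.out.println(totalSteps(1000, 32, 5)); // 160
    }
}
```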
Modifier and Type | Method and Description |
---|---|
int | getInvocationCount() The number of times this trainer instance has had its train method invoked. |
TrainerProvenance | getProvenance() |
void | postConfig() Used by the OLCUT configuration system, and should not be called by external code. |
String | toString() |
Model<T> | train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) Trains a predictive model using the examples in the given data set. |
public static final String MODEL_FILENAME
public TensorflowCheckpointTrainer(Path graphPath, Path checkpointRootPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int minibatchSize, int epochs) throws IOException
Parameters:
graphPath - The graph to load.
checkpointRootPath - The checkpoint path to save to.
exampleTransformer - The feature transformer.
outputTransformer - The output transformer.
minibatchSize - The training batch size.
epochs - The number of training epochs.
Throws:
IOException - If the graph failed to load.

public void postConfig() throws IOException
Specified by: postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
Throws: IOException
public Model<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
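Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.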
public int getInvocationCount()
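Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.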
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by: getInvocationCount in interface Trainer<T extends Output<T>>
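The invocation count matters because a trainer that draws a fresh random stream on every `train` call produces a different model on its second call than on its first. One common way to make this reproducible is to record how many splits of a seeded `java.util.SplittableRandom` have happened, so that a new instance can be fast-forwarded to the same point. The sketch below assumes the trainer derives each run's generator this way. `CountingTrainer` and `setInvocationCount` are hypothetical names for illustration, since this page does not describe the trainer's internals.

```java
import java.util.SplittableRandom;

/** Hypothetical trainer skeleton showing seed + invocation count replicability. */
public final class CountingTrainer {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    public CountingTrainer(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train(): derives a per-run RNG and returns a value drawn from it. */
    public synchronized long train() {
        SplittableRandom localRng = rng.split(); // split() also advances the parent stream
        invocationCount++;
        return localRng.nextLong(); // stands in for "the trained model"
    }

    public synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Rebuilds the RNG from the seed and replays the given number of splits. */
    public synchronized void setInvocationCount(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count must be >= 0");
        }
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split();
        }
        invocationCount = count;
    }

    public static void main(String[] args) {
        CountingTrainer a = new CountingTrainer(42L);
        a.train();
        long second = a.train();

        CountingTrainer b = new CountingTrainer(42L);
        b.setInvocationCount(a.getInvocationCount() - 1); // skip past one call
        System.out.println(b.train() == second); // true
    }
}
```

Replaying the splits costs a loop proportional to the count. That is usually negligible next to training, and it keeps the seed as the only state that has to be stored.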
public TrainerProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.