Class TensorFlowSequenceTrainer<T extends Output<T>>

java.lang.Object
org.tribuo.interop.tensorflow.sequence.TensorFlowSequenceTrainer<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SequenceTrainer<T>

public class TensorFlowSequenceTrainer<T extends Output<T>> extends Object implements SequenceTrainer<T>
A trainer for SequenceModels which use an underlying TensorFlow graph.

N.B. TensorFlow support is experimental and may change without a major version bump.

  • Field Details

    • graphPath

      @Config(mandatory=true, description="Path to the protobuf containing the TensorFlow graph.") protected Path graphPath
    • featureConverter

      @Config(mandatory=true, description="Sequence feature extractor.") protected SequenceFeatureConverter featureConverter
    • outputConverter

      @Config(mandatory=true, description="Sequence output extractor.") protected SequenceOutputConverter<T> outputConverter
    • minibatchSize

      @Config(description="Minibatch size") protected int minibatchSize
    • epochs

      @Config(description="Number of SGD epochs to run.") protected int epochs
    • loggingInterval

      @Config(description="Logging interval to print the loss.") protected int loggingInterval
    • seed

      @Config(description="Seed for the RNG.") protected long seed
    • trainOp

      @Config(mandatory=true, description="Name of the training operation.") protected String trainOp
    • getLossOp

      @Config(mandatory=true, description="Name of the loss operation (to inspect the loss).") protected String getLossOp
    • predictOp

      @Config(mandatory=true, description="Name of the prediction operation.") protected String predictOp
    • rng

      protected SplittableRandom rng
    • trainInvocationCounter

      protected int trainInvocationCounter
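      The minibatchSize, epochs, loggingInterval and seed fields together define a training schedule. The sketch below illustrates one way such a schedule could work. It is not Tribuo's implementation: the class and method names are hypothetical, the per-epoch shuffle is an assumption, and the step where trainOp would run and getLossOp would be fetched is only a placeholder comment.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.SplittableRandom;

/**
 * Hypothetical sketch of an epoch/minibatch schedule driven by
 * minibatchSize, epochs, loggingInterval and seed. Not Tribuo code.
 */
public class MinibatchScheduleSketch {

    /** Returns the example indices of every minibatch, in the order visited. */
    public static List<int[]> schedule(int numExamples, int minibatchSize, int epochs,
                                       int loggingInterval, long seed) {
        SplittableRandom rng = new SplittableRandom(seed);
        List<int[]> batches = new ArrayList<>();
        int[] order = new int[numExamples];
        for (int i = 0; i < numExamples; i++) {
            order[i] = i;
        }
        int iteration = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Assumed: Fisher-Yates shuffle with the seeded RNG, so runs are reproducible.
            for (int i = numExamples - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[j];
                order[j] = tmp;
            }
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                int[] batch = new int[end - start];
                System.arraycopy(order, start, batch, 0, end - start);
                batches.add(batch);
                // Placeholder: a real trainer would run trainOp here and fetch getLossOp.
                if (loggingInterval > 0 && iteration % loggingInterval == 0) {
                    System.out.println("epoch " + epoch + ", iteration " + iteration);
                }
                iteration++;
            }
        }
        return batches;
    }

    public static void main(String[] args) {
        List<int[]> batches = schedule(10, 4, 2, 2, 12345L);
        System.out.println("minibatches run: " + batches.size());
    }
}
```

      With 10 examples and a minibatch size of 4, each epoch yields batches of 4, 4 and 2 examples, so the final partial batch is kept rather than dropped.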
  • Constructor Details

    • TensorFlowSequenceTrainer

      public TensorFlowSequenceTrainer(Path graphPath, SequenceFeatureConverter featureConverter, SequenceOutputConverter<T> outputConverter, int minibatchSize, int epochs, int loggingInterval, long seed, String trainOp, String getLossOp, String predictOp) throws IOException
      Constructs a TensorFlowSequenceTrainer using the specified parameters.
      Parameters:
      graphPath - The path to the TF graph.
      featureConverter - The feature conversion object.
      outputConverter - The output conversion object.
      minibatchSize - The training minibatch size.
      epochs - The number of training epochs.
      loggingInterval - The interval at which the training loss is logged.
      seed - The RNG seed.
      trainOp - The name of the training operation.
      getLossOp - The name of the loss operation.
      predictOp - The name of the prediction operation.
      Throws:
      IOException - If the graph cannot be read from disk.
  • Method Details

    • postConfig

      public void postConfig() throws IOException
      Used by the OLCUT configuration system and should not be called by external code.
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
      Throws:
      IOException - If the graph cannot be read from disk.
    • train

      public SequenceModel<T> train(SequenceDataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
      Description copied from interface: SequenceTrainer
      Trains a sequence prediction model using the examples in the given data set.
      Specified by:
      train in interface SequenceTrainer<T>
      Parameters:
      examples - the data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      Returns:
      a predictive model that can be used to generate predictions for new examples.
    • getInvocationCount

      public int getInvocationCount()
      Description copied from interface: SequenceTrainer
      Returns the number of times the train method has been invoked.
      Specified by:
      getInvocationCount in interface SequenceTrainer<T>
      Returns:
      The number of times train has been invoked.
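      The class exposes both an rng field and a trainInvocationCounter field. The sketch below shows one common way these could be combined: each call to train atomically splits a fresh RNG off the seeded parent and bumps the counter, so two trainers built with the same seed reproduce the same sequence of runs. The class and method bodies are hypothetical stand-ins, not the Tribuo source.

```java
import java.util.SplittableRandom;

/**
 * Hypothetical sketch: tracking train invocations and deriving a
 * per-call RNG from a seeded parent. Not Tribuo code.
 */
public class InvocationCounterSketch {
    private final SplittableRandom rng;
    private int trainInvocationCounter = 0;

    public InvocationCounterSketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train(...): returns a value drawn from the per-call RNG. */
    public long train() {
        SplittableRandom localRng;
        // Split and count under one lock so concurrent calls stay consistent.
        synchronized (this) {
            localRng = rng.split();
            trainInvocationCounter++;
        }
        return localRng.nextLong();
    }

    public synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }
}
```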
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • preTrainingHook

      protected void preTrainingHook(org.tensorflow.Session session, SequenceDataset<T> examples)
      A hook for modifying the session state before training starts.

      This should not mutate any state in the trainer.

      Parameters:
      session - The session to modify.
      examples - The dataset.
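      preTrainingHook follows the template-method pattern: the base trainer calls a hook that does nothing by default, and a subclass overrides it to prepare the session before the training loop starts. The sketch below illustrates that pattern in isolation. FakeSession is a hypothetical stand-in for org.tensorflow.Session, and the "embeddings" variable name is invented for illustration; none of these names are Tribuo or TensorFlow API.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical sketch of a pre-training hook. FakeSession stands in
 * for a TensorFlow session; nothing here is Tribuo API.
 */
public class PreTrainingHookSketch {

    /** Minimal stand-in for a session's mutable variable state. */
    public static class FakeSession {
        public final Map<String, float[]> variables = new HashMap<>();
    }

    /** Base trainer: the hook is a no-op by default. */
    public static class BaseTrainer {
        protected void preTrainingHook(FakeSession session, int numExamples) {
            // Default: leave the session untouched.
        }

        public FakeSession train(int numExamples) {
            FakeSession session = new FakeSession();
            preTrainingHook(session, numExamples);
            // ... the training loop would run here ...
            return session;
        }
    }

    /** Subclass that initialises a (hypothetical) variable before training. */
    public static class WarmStartTrainer extends BaseTrainer {
        @Override
        protected void preTrainingHook(FakeSession session, int numExamples) {
            // Mutates only the session, never the trainer's own fields.
            session.variables.put("embeddings", new float[] {0.1f, 0.2f, 0.3f});
        }
    }
}
```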
    • getHyperparameterFeed

      protected TensorMap getHyperparameterFeed()
      Builds any necessary non-data parameter tensors.

      The default implementation returns an empty map, and should be overridden if necessary.

      This should not mutate any state in the trainer.

      Returns:
      The parameter tensors.
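      The sketch below illustrates the override pattern for getHyperparameterFeed: the default returns an empty feed, and a subclass returns a freshly built map of named non-data inputs that is merged with the per-minibatch data feed. A Map of named scalars stands in for Tribuo's TensorMap, the merge step is an assumption about how such a feed might be consumed, and the "dropout_keep_prob" input name is hypothetical.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical sketch of a hyperparameter feed. Map<String, Float>
 * stands in for TensorMap; "dropout_keep_prob" is an invented name.
 */
public class HyperparameterFeedSketch {

    public static class BaseTrainer {
        /** Default: no extra non-data inputs. */
        protected Map<String, Float> getHyperparameterFeed() {
            return Collections.emptyMap();
        }

        /** Assumed usage: merge the data feed with the hyperparameter feed. */
        public Map<String, Float> buildFeed(Map<String, Float> dataFeed) {
            Map<String, Float> feed = new HashMap<>(dataFeed);
            feed.putAll(getHyperparameterFeed());
            return feed;
        }
    }

    public static class DropoutTrainer extends BaseTrainer {
        private final float keepProb;

        public DropoutTrainer(float keepProb) {
            this.keepProb = keepProb;
        }

        @Override
        protected Map<String, Float> getHyperparameterFeed() {
            // Builds a fresh map on each call; trainer fields are only read.
            Map<String, Float> feed = new HashMap<>();
            feed.put("dropout_keep_prob", keepProb);
            return feed;
        }
    }
}
```

      Returning a new map on every call keeps the method free of side effects, in line with the requirement that it not mutate trainer state.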
    • getProvenance

      public TrainerProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>