Interface SequenceOutputConverter<T extends Output<T>>

All Superinterfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, ProtoSerializable<org.tribuo.interop.tensorflow.protos.SequenceOutputConverterProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable

public interface SequenceOutputConverter<T extends Output<T>> extends com.oracle.labs.mlrg.olcut.config.Configurable, ProtoSerializable<org.tribuo.interop.tensorflow.protos.SequenceOutputConverterProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable
Converts a TensorFlow output tensor into a list of predictions, and a Tribuo sequence example into a TensorFlow tensor suitable for training.

N.B. TensorFlow support is experimental and may change without a major version bump.

  • Method Details

    • decode

      List<Prediction<T>> decode(org.tensorflow.Tensor output, SequenceExample<T> input, ImmutableOutputInfo<T> labelMap)
      Decodes a tensor of graph output into a list of predictions for the input sequence.
      Parameters:
      output - graph output
      input - original input sequence example
      labelMap - label domain
      Returns:
      the model's decoded predictions for the input sequence, one per element of the sequence.
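      The following sketch shows one way a single-sequence decode step could work. It assumes the graph emits a [sequenceLength, numLabels] matrix of scores and picks the highest-scoring label id at each position. Plain Java arrays stand in for `org.tensorflow.Tensor`, and a `String[]` indexed by label id stands in for `ImmutableOutputInfo<T>`. The class and method names are hypothetical and are not part of Tribuo's API.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: argmax decoding of one sequence's score matrix (not Tribuo's API). */
public class SingleSequenceDecodeSketch {

    /**
     * Decodes a [sequenceLength][numLabels] score matrix into one label per position.
     * idToLabel is a stand-in for the label domain: index = label id.
     */
    public static List<String> decode(float[][] scores, String[] idToLabel) {
        List<String> predictions = new ArrayList<>(scores.length);
        for (float[] row : scores) {
            // Pick the label id with the highest score at this position.
            int best = 0;
            for (int j = 1; j < row.length; j++) {
                if (row[j] > row[best]) {
                    best = j;
                }
            }
            predictions.add(idToLabel[best]);
        }
        return predictions;
    }

    public static void main(String[] args) {
        float[][] scores = {
            {0.1f, 0.7f, 0.2f},
            {0.6f, 0.3f, 0.1f}
        };
        String[] labels = {"O", "B-PER", "I-PER"};
        System.out.println(decode(scores, labels)); // [B-PER, O]
    }
}
```

      A real converter would also wrap each label in a `Prediction<T>` and might attach per-label scores. This sketch keeps only the argmax label so that the per-position mapping stays visible.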
    • decode

      List<List<Prediction<T>>> decode(org.tensorflow.Tensor outputs, List<SequenceExample<T>> inputBatch, ImmutableOutputInfo<T> labelMap)
      Decodes the graph output tensor corresponding to a batch of input sequences.
      Parameters:
      outputs - a tensor corresponding to a batch of outputs.
      inputBatch - the original input batch.
      labelMap - label domain
      Returns:
      the model's decoded predictions, one for each example in the input batch.
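      Batched output is usually padded to the longest sequence in the batch, so a batch decoder has to discard padded positions. The sketch below assumes a [batchSize, maxLength, numLabels] score array and an `int[]` of true sequence lengths. The lengths array stands in for the sizes of the original `SequenceExample`s. All names are hypothetical and not Tribuo's API.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: decoding a padded batch of sequence scores (not Tribuo's API). */
public class BatchDecodeSketch {

    /**
     * scores has shape [batchSize][maxLength][numLabels]; lengths[i] is the true length
     * of sequence i, so positions at or beyond lengths[i] are padding and are skipped.
     */
    public static List<List<String>> decode(float[][][] scores, int[] lengths, String[] idToLabel) {
        if (scores.length != lengths.length) {
            throw new IllegalArgumentException(
                "Batch size mismatch: " + scores.length + " vs " + lengths.length);
        }
        List<List<String>> batch = new ArrayList<>(scores.length);
        for (int i = 0; i < scores.length; i++) {
            List<String> sequence = new ArrayList<>(lengths[i]);
            for (int t = 0; t < lengths[i]; t++) {
                // Argmax over the label scores at position t of sequence i.
                float[] row = scores[i][t];
                int best = 0;
                for (int j = 1; j < row.length; j++) {
                    if (row[j] > row[best]) {
                        best = j;
                    }
                }
                sequence.add(idToLabel[best]);
            }
            batch.add(sequence);
        }
        return batch;
    }

    public static void main(String[] args) {
        float[][][] scores = {
            {{0.9f, 0.1f}, {0.2f, 0.8f}},
            {{0.3f, 0.7f}, {0.5f, 0.5f}} // second position of this sequence is padding
        };
        int[] lengths = {2, 1};
        String[] labels = {"NEG", "POS"};
        System.out.println(decode(scores, lengths, labels)); // [[NEG, POS], [POS]]
    }
}
```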
    • encode

      TensorMap encode(SequenceExample<T> example, ImmutableOutputInfo<T> labelMap)
      Encodes a sequence example's labels as a feed dict.
      Parameters:
      example - the input example
      labelMap - label domain
      Returns:
      a map from graph placeholder names to their fed-in values.
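      The sketch below shows what "a map from graph placeholder names to their fed-in values" could look like for one sequence. Each label is mapped to its integer id, and the resulting array is stored under a placeholder name. A `Map<String, Integer>` stands in for the label domain, and an `int[]` stands in for a TensorFlow tensor. The placeholder name `sequence_labels` and the class name are hypothetical. In Tribuo the return type is `TensorMap`, not a plain `Map`.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: encoding one sequence's labels as a feed dict (not Tribuo's API). */
public class SingleSequenceEncodeSketch {

    // Hypothetical placeholder name; a real graph defines its own input names.
    static final String LABEL_PLACEHOLDER = "sequence_labels";

    public static Map<String, int[]> encode(List<String> labels, Map<String, Integer> labelToId) {
        int[] ids = new int[labels.size()];
        for (int t = 0; t < ids.length; t++) {
            Integer id = labelToId.get(labels.get(t));
            if (id == null) {
                // A label outside the domain cannot be encoded.
                throw new IllegalArgumentException("Label not in domain: " + labels.get(t));
            }
            ids[t] = id;
        }
        Map<String, int[]> feed = new HashMap<>();
        feed.put(LABEL_PLACEHOLDER, ids);
        return feed;
    }

    public static void main(String[] args) {
        Map<String, Integer> domain = Map.of("O", 0, "B-PER", 1, "I-PER", 2);
        Map<String, int[]> feed = encode(List.of("B-PER", "I-PER", "O"), domain);
        System.out.println(Arrays.toString(feed.get(LABEL_PLACEHOLDER))); // [1, 2, 0]
    }
}
```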
    • encode

      TensorMap encode(List<SequenceExample<T>> batch, ImmutableOutputInfo<T> labelMap)
      Encodes the labels of a batch of sequence examples as a feed dict.
      Parameters:
      batch - a batch of examples.
      labelMap - label domain
      Returns:
      a map from graph placeholder names to their fed-in values.
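      A batch encoder typically has to pack sequences of different lengths into one rectangular tensor. The sketch below pads every sequence to the longest one in the batch. It assumes a padding id of -1 and the placeholder name `sequence_labels`; both are hypothetical choices for illustration. A real graph might instead expect a mask or a separate lengths input.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: encoding a batch of label sequences as one padded matrix (not Tribuo's API). */
public class BatchEncodeSketch {

    static final String LABEL_PLACEHOLDER = "sequence_labels"; // hypothetical placeholder name
    static final int PAD_ID = -1; // hypothetical value for positions past a sequence's end

    public static Map<String, int[][]> encode(List<List<String>> batch, Map<String, Integer> labelToId) {
        // The padded width is the length of the longest sequence in the batch.
        int maxLength = 0;
        for (List<String> sequence : batch) {
            maxLength = Math.max(maxLength, sequence.size());
        }
        int[][] ids = new int[batch.size()][maxLength];
        for (int i = 0; i < batch.size(); i++) {
            Arrays.fill(ids[i], PAD_ID);
            List<String> sequence = batch.get(i);
            for (int t = 0; t < sequence.size(); t++) {
                Integer id = labelToId.get(sequence.get(t));
                if (id == null) {
                    throw new IllegalArgumentException("Label not in domain: " + sequence.get(t));
                }
                ids[i][t] = id;
            }
        }
        Map<String, int[][]> feed = new HashMap<>();
        feed.put(LABEL_PLACEHOLDER, ids);
        return feed;
    }

    public static void main(String[] args) {
        Map<String, Integer> domain = Map.of("O", 0, "B-PER", 1, "I-PER", 2);
        List<List<String>> batch = List.of(
            List.of("B-PER", "I-PER", "O"),
            List.of("O"));
        int[][] ids = encode(batch, domain).get(LABEL_PLACEHOLDER);
        System.out.println(Arrays.deepToString(ids)); // [[1, 2, 0], [0, -1, -1]]
    }
}
```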
    • getTypeWitness

      default Class<T> getTypeWitness()
      The type witness used when deserializing the TensorFlow model from a protobuf.

      The default implementation throws UnsupportedOperationException for compatibility with implementations that don't use protobuf serialization. This default implementation will be removed in the next major version of Tribuo.

      Returns:
      The output class this object produces.
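      The following sketch shows the plain-Java pattern described above. A default method throws `UnsupportedOperationException`, so older implementations still compile, and implementations that support deserialization override it to return their output class. `TypedConverter`, `LegacyConverter` and `ModernConverter` are hypothetical names, and `String` stands in for a Tribuo `Output` type.

```java
/** Hypothetical sketch of the type-witness pattern; not Tribuo's actual types. */
public class TypeWitnessSketch {

    interface TypedConverter<T> {
        // Throwing default keeps implementations written before this method compiling.
        default Class<T> getTypeWitness() {
            throw new UnsupportedOperationException(
                "This converter does not support protobuf serialization");
        }
    }

    // Legacy-style implementation: inherits the throwing default.
    static final class LegacyConverter implements TypedConverter<String> { }

    // Implementation that reports its output class, e.g. so a deserializer can check types.
    static final class ModernConverter implements TypedConverter<String> {
        @Override
        public Class<String> getTypeWitness() {
            return String.class;
        }
    }

    public static void main(String[] args) {
        System.out.println(new ModernConverter().getTypeWitness().getSimpleName()); // String
        try {
            new LegacyConverter().getTypeWitness();
        } catch (UnsupportedOperationException e) {
            System.out.println("legacy: " + e.getMessage());
        }
    }
}
```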