Interface SequenceOutputConverter<T extends Output<T>>
- All Superinterfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, ProtoSerializable<org.tribuo.interop.tensorflow.protos.SequenceOutputConverterProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable
public interface SequenceOutputConverter<T extends Output<T>>
extends com.oracle.labs.mlrg.olcut.config.Configurable, ProtoSerializable<org.tribuo.interop.tensorflow.protos.SequenceOutputConverterProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable
Converts a TensorFlow output tensor into a list of predictions, and a Tribuo sequence example into
a TensorFlow tensor suitable for training.
N.B. TensorFlow support is experimental and may change without a major version bump.
Field Summary
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
Method Summary
- List<List<Prediction<T>>> decode(org.tensorflow.Tensor outputs, List<SequenceExample<T>> inputBatch, ImmutableOutputInfo<T> labelMap)
Decode graph output tensors corresponding to a batch of input sequences.
- List<Prediction<T>> decode(org.tensorflow.Tensor output, SequenceExample<T> input, ImmutableOutputInfo<T> labelMap)
Decode a tensor of graph output into a list of predictions for the input sequence.
- encode(List<SequenceExample<T>> batch, ImmutableOutputInfo<T> labelMap)
Encodes a batch of labels as a feed dict.
- encode(SequenceExample<T> example, ImmutableOutputInfo<T> labelMap)
Encodes an example's label as a feed dict.
- getTypeWitness()
The type witness used when deserializing the TensorFlow model from a protobuf.
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable
postConfig
Methods inherited from interface org.tribuo.protos.ProtoSerializable
serialize
Methods inherited from interface com.oracle.labs.mlrg.olcut.provenance.Provenancable
getProvenance
Method Details
decode
List<Prediction<T>> decode(org.tensorflow.Tensor output, SequenceExample<T> input, ImmutableOutputInfo<T> labelMap)
Decode a tensor of graph output into a list of predictions for the input sequence.
- Parameters:
output - graph output
input - original input sequence example
labelMap - label domain
- Returns:
the model's decoded predictions for the input sequence.
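Single-sequence decoding can be illustrated with a minimal sketch. It assumes, for illustration only, that the graph emits one row of label scores per sequence position and that decoding picks the highest-scoring label at each position. A float[][] stands in for org.tensorflow.Tensor, a List<String> of label names indexed by label id stands in for ImmutableOutputInfo<T>, and each Prediction<T> is reduced to its label name. The class and method names are hypothetical and are not Tribuo's API.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch of single-sequence decoding. A float[][] stands in for
 * org.tensorflow.Tensor and a List<String> of label names (indexed by label id)
 * stands in for ImmutableOutputInfo<T>. This is not Tribuo's API.
 */
final class SequenceDecodeSketch {

    private SequenceDecodeSketch() {}

    /**
     * Decodes scores assumed to be shaped [sequenceLength][numLabels] into one
     * label name per sequence position by choosing the highest-scoring label id.
     */
    static List<String> decode(float[][] scores, List<String> labelNames) {
        List<String> predictions = new ArrayList<>(scores.length);
        for (float[] row : scores) {
            if (row.length != labelNames.size()) {
                throw new IllegalArgumentException(
                        "Expected " + labelNames.size() + " scores per position, found " + row.length);
            }
            // Argmax over the label dimension for this position.
            int best = 0;
            for (int id = 1; id < row.length; id++) {
                if (row[id] > row[best]) {
                    best = id;
                }
            }
            predictions.add(labelNames.get(best));
        }
        return predictions;
    }
}
```

A fuller version would presumably keep the per-label scores as well, since a prediction usually carries more than the winning label.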
-
decode
List<List<Prediction<T>>> decode(org.tensorflow.Tensor outputs, List<SequenceExample<T>> inputBatch, ImmutableOutputInfo<T> labelMap)
Decode graph output tensors corresponding to a batch of input sequences.
- Parameters:
outputs - a tensor corresponding to a batch of outputs.
inputBatch - the original input batch.
labelMap - label domain
- Returns:
the model's decoded predictions, one for each example in the input batch.
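The batch variant can be sketched under a further assumption: that the batch output is shaped [batchSize][maxLength][numLabels], with shorter sequences padded up to the longest one, so each example's true length (which a real implementation would read from inputBatch) is needed to drop padded positions. The int[] of lengths, the array stand-in for the tensor, and all names here are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch of batch decoding over padded output. A float[][][]
 * stands in for the batch tensor and int[] lengths stands in for the sizes
 * of the original input sequences. This is not Tribuo's API.
 */
final class BatchDecodeSketch {

    private BatchDecodeSketch() {}

    /** Returns one list of label names per example, ignoring padded positions. */
    static List<List<String>> decodeBatch(float[][][] outputs, int[] lengths, List<String> labelNames) {
        if (outputs.length != lengths.length) {
            throw new IllegalArgumentException("Batch size mismatch: " + outputs.length + " vs " + lengths.length);
        }
        List<List<String>> result = new ArrayList<>(outputs.length);
        for (int i = 0; i < outputs.length; i++) {
            List<String> sequence = new ArrayList<>(lengths[i]);
            // Only the first lengths[i] positions belong to the real sequence.
            for (int pos = 0; pos < lengths[i]; pos++) {
                float[] row = outputs[i][pos];
                int best = 0;
                for (int id = 1; id < row.length; id++) {
                    if (row[id] > row[best]) {
                        best = id;
                    }
                }
                sequence.add(labelNames.get(best));
            }
            result.add(sequence);
        }
        return result;
    }
}
```

Keeping the outer list aligned with the input batch is what lets a caller pair each decoded list with its original example.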
-
encode
encode(SequenceExample<T> example, ImmutableOutputInfo<T> labelMap)
Encodes an example's label as a feed dict.
- Parameters:
example - the input example
labelMap - label domain
- Returns:
a map from graph placeholder names to their fed-in values.
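A feed dict is a map from placeholder names to values. The sketch below assumes, for illustration, that a sequence's labels are fed as one integer id per position under a single placeholder. The placeholder name "labels" is hypothetical (a real graph defines its own), a Map<String, Integer> stands in for ImmutableOutputInfo<T>, and an int[] stands in for the tensor.

```java
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch of encoding one sequence's labels as a feed dict.
 * Not Tribuo's API: plain Java types replace the Tribuo and TensorFlow types.
 */
final class SequenceEncodeSketch {

    /** Hypothetical placeholder name; a real graph defines its own. */
    static final String LABEL_PLACEHOLDER = "labels";

    private SequenceEncodeSketch() {}

    /** Maps each position's label name to its id and keys the result by placeholder name. */
    static Map<String, int[]> encode(List<String> sequenceLabels, Map<String, Integer> labelIds) {
        int[] ids = new int[sequenceLabels.size()];
        for (int pos = 0; pos < ids.length; pos++) {
            Integer id = labelIds.get(sequenceLabels.get(pos));
            if (id == null) {
                throw new IllegalArgumentException("Unknown label: " + sequenceLabels.get(pos));
            }
            ids[pos] = id;
        }
        return Collections.singletonMap(LABEL_PLACEHOLDER, ids);
    }
}
```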
-
encode
encode(List<SequenceExample<T>> batch, ImmutableOutputInfo<T> labelMap)
Encodes a batch of labels as a feed dict.
- Parameters:
batch - a batch of examples.
labelMap - label domain
- Returns:
a map from graph placeholder names to their fed-in values.
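Batching sequences of different lengths usually requires padding. This sketch assumes one padding scheme, chosen for illustration: label ids are zero-padded to the longest sequence, and a 0/1 mask marks which positions are real, since 0 may also be a valid label id. Both placeholder names ("labels" and "mask") are hypothetical, as is the use of plain int[][] arrays in place of tensors.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch of encoding a batch of label sequences as a feed dict
 * with padding and a mask. Not Tribuo's API.
 */
final class BatchEncodeSketch {

    /** Hypothetical placeholder names; a real graph defines its own. */
    static final String LABEL_PLACEHOLDER = "labels";
    static final String MASK_PLACEHOLDER = "mask";

    private BatchEncodeSketch() {}

    static Map<String, int[][]> encodeBatch(List<List<String>> batch, Map<String, Integer> labelIds) {
        int maxLength = 0;
        for (List<String> sequence : batch) {
            maxLength = Math.max(maxLength, sequence.size());
        }
        int[][] ids = new int[batch.size()][maxLength];  // Java zero-fills, so padding is 0
        int[][] mask = new int[batch.size()][maxLength]; // 1 = real position, 0 = padding
        for (int i = 0; i < batch.size(); i++) {
            List<String> sequence = batch.get(i);
            for (int pos = 0; pos < sequence.size(); pos++) {
                Integer id = labelIds.get(sequence.get(pos));
                if (id == null) {
                    throw new IllegalArgumentException("Unknown label: " + sequence.get(pos));
                }
                ids[i][pos] = id;
                mask[i][pos] = 1;
            }
        }
        Map<String, int[][]> feed = new HashMap<>();
        feed.put(LABEL_PLACEHOLDER, ids);
        feed.put(MASK_PLACEHOLDER, mask);
        return feed;
    }
}
```

The mask lets a loss function skip padded positions; whether a given graph expects one is specific to that graph.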
-
getTypeWitness
The type witness used when deserializing the TensorFlow model from a protobuf.
The default implementation throws UnsupportedOperationException for compatibility with implementations that don't use protobuf serialization. This default implementation will be removed in the next major version of Tribuo.
- Returns:
the output class this object produces.
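The compatibility strategy described above is the standard Java pattern of a throwing default method: older implementations compile without change, and only implementations that support protobuf serialization override it. The sketch below shows the pattern on a hypothetical interface; the interface and class names are invented, and the Class<T> return type is an assumption based on the "output class" wording.

```java
/** Hypothetical interface illustrating a throwing default method. Not Tribuo's API. */
interface TypeWitnessed<T> {
    /**
     * Implementations written before this method existed inherit this body,
     * so they still compile; calling it on them fails at runtime instead.
     */
    default Class<T> getTypeWitness() {
        throw new UnsupportedOperationException("This implementation does not support protobuf serialization.");
    }
}

/** An older-style implementation that keeps the throwing default. */
final class LegacyConverter implements TypeWitnessed<String> {}

/** A newer implementation that supplies the witness used during deserialization. */
final class ProtoAwareConverter implements TypeWitnessed<String> {
    @Override
    public Class<String> getTypeWitness() {
        return String.class;
    }
}
```

Once the default body is removed, every implementation must override the method, which is why that change is reserved for a major version.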