Package org.tribuo.interop.tensorflow
Interface OutputConverter<T extends Output<T>>
- Type Parameters:
T - The output type.
- All Superinterfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, ProtoSerializable<org.tribuo.interop.tensorflow.protos.OutputConverterProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable
- All Known Implementing Classes:
LabelConverter, MultiLabelConverter, RegressorConverter
public interface OutputConverter<T extends Output<T>>
extends com.oracle.labs.mlrg.olcut.config.Configurable, ProtoSerializable<org.tribuo.interop.tensorflow.protos.OutputConverterProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable
Converts the Output into a Tensor and vice versa.
Also provides the loss function for this output type, along with the function which converts the TF graph output into a well-formed output float (e.g., a softmax for classification, a sigmoid for multi-label, or the identity function for regression).
N.B. TensorFlow support is experimental and may change without a major version bump.
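To make the two conversion directions concrete, here is a minimal plain-Java sketch of the classification case. It deliberately uses neither Tribuo nor TensorFlow: a `float[]` stands in for `org.tensorflow.Tensor`, and a label-to-id map (assumed to use contiguous ids 0..n-1) stands in for `ImmutableOutputInfo`. The class `LabelRoundTrip` and its methods are hypothetical, not part of the Tribuo API.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

// Hypothetical sketch, not the Tribuo API: float[] stands in for a Tensor and
// Map<String, Integer> stands in for ImmutableOutputInfo (ids assumed 0..n-1).
final class LabelRoundTrip {
    private LabelRoundTrip() {}

    // Output -> "tensor": one-hot encode the label at its id.
    static float[] toOneHot(String label, Map<String, Integer> labelToId) {
        Integer id = labelToId.get(label);
        if (id == null) {
            throw new IllegalArgumentException("Unknown label: " + label);
        }
        float[] vector = new float[labelToId.size()];
        vector[id] = 1.0f;
        return vector;
    }

    // "tensor" -> Output: take the highest-scoring id and map it back to its label.
    static String fromScores(float[] scores, List<String> idToLabel) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return idToLabel.get(best);
    }

    public static void main(String[] args) {
        Map<String, Integer> ids = Map.of("cat", 0, "dog", 1, "bird", 2);
        List<String> labels = List.of("cat", "dog", "bird");
        System.out.println(Arrays.toString(toOneHot("dog", ids))); // [0.0, 1.0, 0.0]
        System.out.println(fromScores(new float[] {0.1f, 0.2f, 0.7f}, labels)); // bird
    }
}
```

The id map is what keeps the two directions consistent: the position a label is written to when encoding is the same position read back when decoding.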
-
Field Summary
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
-
Method Summary
- List<T> convertToBatchOutput(org.tensorflow.Tensor tensor, ImmutableOutputInfo<T> outputIDInfo)
Converts a Tensor containing multiple outputs into a list of Outputs.
- List<Prediction<T>> convertToBatchPrediction(org.tensorflow.Tensor tensor, ImmutableOutputInfo<T> outputIDInfo, int[] numValidFeatures, List<Example<T>> examples)
Converts a Tensor containing multiple outputs into a list of Predictions.
- T convertToOutput(org.tensorflow.Tensor tensor, ImmutableOutputInfo<T> outputIDInfo)
Converts a Tensor into the specified output type.
- Prediction<T> convertToPrediction(org.tensorflow.Tensor tensor, ImmutableOutputInfo<T> outputIDInfo, int numValidFeatures, Example<T> example)
Converts a Tensor into a Prediction.
- org.tensorflow.Tensor convertToTensor(List<Example<T>> examples, ImmutableOutputInfo<T> outputIDInfo)
Converts a list of Examples into a Tensor representing all the outputs in the list.
- org.tensorflow.Tensor convertToTensor(T output, ImmutableOutputInfo<T> outputIDInfo)
Converts an Output into a Tensor representing its output.
- boolean generatesProbabilities()
Does this OutputConverter generate probabilities?
- default Class<T> getTypeWitness()
The type witness used when deserializing the TensorFlow model from a protobuf.
- BiFunction<org.tensorflow.op.Ops, com.oracle.labs.mlrg.olcut.util.Pair<org.tensorflow.op.core.Placeholder<? extends org.tensorflow.types.family.TNumber>, org.tensorflow.Operand<org.tensorflow.types.family.TNumber>>, org.tensorflow.Operand<org.tensorflow.types.family.TNumber>> loss()
The loss function associated with this prediction type.
- <U extends org.tensorflow.types.family.TNumber> BiFunction<org.tensorflow.op.Ops, org.tensorflow.Operand<U>, org.tensorflow.op.Op> outputTransformFunction()
Produces an output transformation function that applies the operation to the graph from the supplied Ops, taking a graph output operation.
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable
postConfig
Methods inherited from interface org.tribuo.protos.ProtoSerializable
serialize
Methods inherited from interface com.oracle.labs.mlrg.olcut.provenance.Provenancable
getProvenance
-
Method Details
-
loss
BiFunction<org.tensorflow.op.Ops, com.oracle.labs.mlrg.olcut.util.Pair<org.tensorflow.op.core.Placeholder<? extends org.tensorflow.types.family.TNumber>, org.tensorflow.Operand<org.tensorflow.types.family.TNumber>>, org.tensorflow.Operand<org.tensorflow.types.family.TNumber>> loss()
The loss function associated with this prediction type.
- Returns:
- The TF loss function.
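The loss takes the TF `Ops` plus a pair of the target placeholder and the graph output. As an illustration of what such a function computes, here is a plain-Java sketch of softmax cross-entropy over a batch, with `targets` playing the role of the placeholder and `logits` the graph output. This page does not say which loss each implementing class uses; softmax cross-entropy is assumed here only because it is a common choice for single-label classification. `SoftmaxCrossEntropy` is hypothetical and works on arrays, whereas Tribuo builds the loss as graph ops.

```java
// Hypothetical plain-Java sketch of a softmax cross-entropy loss; in Tribuo the
// loss is built from TensorFlow graph ops instead of being computed on arrays.
final class SoftmaxCrossEntropy {
    private SoftmaxCrossEntropy() {}

    // Mean over the batch of -sum_j targets[i][j] * log(softmax(logits[i])_j).
    static double loss(double[][] logits, double[][] targets) {
        if (logits.length == 0 || logits.length != targets.length) {
            throw new IllegalArgumentException("Batch must be non-empty and sizes must match");
        }
        double total = 0.0;
        for (int i = 0; i < logits.length; i++) {
            // Subtract the row max before exponentiating for numerical stability.
            double max = Double.NEGATIVE_INFINITY;
            for (double z : logits[i]) {
                max = Math.max(max, z);
            }
            double sumExp = 0.0;
            for (double z : logits[i]) {
                sumExp += Math.exp(z - max);
            }
            double logSumExp = max + Math.log(sumExp);
            for (int j = 0; j < logits[i].length; j++) {
                // log softmax(z)_j = z_j - logSumExp
                total -= targets[i][j] * (logits[i][j] - logSumExp);
            }
        }
        return total / logits.length;
    }

    public static void main(String[] args) {
        // Equal logits over two classes give a loss of ln 2 (about 0.693).
        System.out.println(loss(new double[][] {{0.0, 0.0}}, new double[][] {{1.0, 0.0}}));
    }
}
```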
-
outputTransformFunction
<U extends org.tensorflow.types.family.TNumber> BiFunction<org.tensorflow.op.Ops, org.tensorflow.Operand<U>, org.tensorflow.op.Op> outputTransformFunction()
Produces an output transformation function that applies the operation to the graph from the supplied Ops, taking a graph output operation. For example, this function will apply a softmax for classification, a sigmoid for multi-label, or the identity function for regression.
- Type Parameters:
U - The type of the graph output.
- Returns:
- A function which applies the appropriate transformation function.
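The three transformations named above can be illustrated outside TensorFlow. The sketch below applies them to plain `double[]` score vectors; `OutputTransforms` and its fields are hypothetical, and in Tribuo the transformation is instead added to the graph as an op via the supplied `Ops`.

```java
import java.util.function.UnaryOperator;

// Hypothetical plain-Java versions of the three transforms; in Tribuo the
// transform is applied as a TensorFlow op on the graph output instead.
final class OutputTransforms {
    private OutputTransforms() {}

    // Classification: softmax turns raw scores into a probability distribution.
    static final UnaryOperator<double[]> SOFTMAX = z -> {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            max = Math.max(max, v);
        }
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max); // shifted by max for numerical stability
            sum += out[i];
        }
        for (int i = 0; i < z.length; i++) {
            out[i] /= sum;
        }
        return out;
    };

    // Multi-label: an independent sigmoid per label, so scores need not sum to 1.
    static final UnaryOperator<double[]> SIGMOID = z -> {
        double[] out = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            out[i] = 1.0 / (1.0 + Math.exp(-z[i]));
        }
        return out;
    };

    // Regression: identity (returns a copy so callers do not alias the input).
    static final UnaryOperator<double[]> IDENTITY = z -> z.clone();
}
```

The key design difference is visible in the code: softmax couples all outputs through the shared normalising sum, whereas sigmoid treats each label independently, which is what lets several labels be "on" at once.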
-
convertToPrediction
Prediction<T> convertToPrediction(org.tensorflow.Tensor tensor, ImmutableOutputInfo<T> outputIDInfo, int numValidFeatures, Example<T> example)
Converts a Tensor into a Prediction.
- Parameters:
tensor - The tensor to convert.
outputIDInfo - The output info to use to identify the outputs.
numValidFeatures - The number of valid features used by the prediction.
example - The example to insert into the prediction.
- Returns:
- A prediction object.
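A prediction bundles more than the winning output: it also carries per-output scores and the number of valid features. The plain-Java sketch below shows that assembly step for one row of already-transformed (e.g. softmaxed) scores. `SimplePrediction` and `PredictionSketch` are hypothetical stand-ins rather than Tribuo's `Prediction`, and the `Example` argument is omitted to keep the sketch self-contained.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for Tribuo's Prediction: the winning label, a score per
// label, and the number of valid features. The Example argument is omitted.
final class SimplePrediction {
    final String output;
    final Map<String, Double> scores;
    final int numValidFeatures;
    final boolean probability;

    SimplePrediction(String output, Map<String, Double> scores, int numValidFeatures, boolean probability) {
        this.output = output;
        this.scores = Collections.unmodifiableMap(scores);
        this.numValidFeatures = numValidFeatures;
        this.probability = probability;
    }
}

final class PredictionSketch {
    private PredictionSketch() {}

    // probs is one row of already-transformed graph output, indexed by label id.
    static SimplePrediction toPrediction(double[] probs, List<String> idToLabel, int numValidFeatures) {
        if (probs.length != idToLabel.size()) {
            throw new IllegalArgumentException("Score vector and label list differ in length");
        }
        Map<String, Double> scores = new LinkedHashMap<>();
        int best = 0;
        for (int i = 0; i < probs.length; i++) {
            scores.put(idToLabel.get(i), probs[i]);
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        // probability = true mirrors a converter whose generatesProbabilities() is true.
        return new SimplePrediction(idToLabel.get(best), scores, numValidFeatures, true);
    }
}
```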
-
convertToOutput
T convertToOutput(org.tensorflow.Tensor tensor, ImmutableOutputInfo<T> outputIDInfo)
Converts a Tensor into the specified output type.
- Parameters:
tensor - The tensor to convert.
outputIDInfo - The output info to use to identify the outputs.
- Returns:
- An output.
-
convertToBatchPrediction
List<Prediction<T>> convertToBatchPrediction(org.tensorflow.Tensor tensor, ImmutableOutputInfo<T> outputIDInfo, int[] numValidFeatures, List<Example<T>> examples)
Converts a Tensor containing multiple outputs into a list of Predictions.
- Parameters:
tensor - The tensor to convert.
outputIDInfo - The output info to use to identify the outputs.
numValidFeatures - The number of valid features used by each prediction.
examples - The examples to insert into the predictions.
- Returns:
- A list of predictions.
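In the batch variant, row `i` of the tensor, `numValidFeatures[i]`, and `examples.get(i)` all describe the same example, so the three inputs have to line up. The sketch below makes that alignment explicit using a `double[][]` in place of the batched Tensor and argmax decoding per row. `BatchPredictionSketch` and `RowPrediction` are hypothetical names, and the length check is an assumption about sensible behaviour, not documented Tribuo behaviour.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: row i of the batch, numValidFeatures[i] and examples.get(i)
// all describe the same example, so their lengths must agree.
final class BatchPredictionSketch {
    private BatchPredictionSketch() {}

    static final class RowPrediction<E> {
        final E example;
        final String label;
        final int numValidFeatures;

        RowPrediction(E example, String label, int numValidFeatures) {
            this.example = example;
            this.label = label;
            this.numValidFeatures = numValidFeatures;
        }
    }

    static <E> List<RowPrediction<E>> toBatch(double[][] batch, List<String> idToLabel,
                                              int[] numValidFeatures, List<E> examples) {
        if (batch.length != numValidFeatures.length || batch.length != examples.size()) {
            throw new IllegalArgumentException("Batch, feature counts and examples must have the same length");
        }
        List<RowPrediction<E>> predictions = new ArrayList<>(batch.length);
        for (int i = 0; i < batch.length; i++) {
            String label = idToLabel.get(argmax(batch[i]));
            predictions.add(new RowPrediction<>(examples.get(i), label, numValidFeatures[i]));
        }
        return predictions;
    }

    // Index of the largest score in a row.
    static int argmax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        return best;
    }
}
```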
-
convertToBatchOutput
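List<T> convertToBatchOutput(org.tensorflow.Tensor tensor, ImmutableOutputInfo<T> outputIDInfo)
Converts a Tensor containing multiple outputs into a list of Outputs.
- Parameters:
tensor - The tensor to convert.
outputIDInfo - The output info to use to identify the outputs.
- Returns: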
- A list of outputs.
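For regression the output transform is the identity, so decoding a batch amounts to reading each row's columns back under their dimension names. The sketch below assumes the column order matches the order of a `dimensionNames` list (a stand-in for the id mapping in `ImmutableOutputInfo`) and uses a `Map<String, Double>` per row in place of a Tribuo `Regressor`; `RegressionBatchSketch` is hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical regression sketch: each row of the (identity-transformed) batch
// holds one value per regression dimension, in the order of dimensionNames.
final class RegressionBatchSketch {
    private RegressionBatchSketch() {}

    static List<Map<String, Double>> toOutputs(double[][] batch, List<String> dimensionNames) {
        List<Map<String, Double>> outputs = new ArrayList<>(batch.length);
        for (double[] row : batch) {
            if (row.length != dimensionNames.size()) {
                throw new IllegalArgumentException(
                        "Expected " + dimensionNames.size() + " values, got " + row.length);
            }
            Map<String, Double> output = new LinkedHashMap<>();
            for (int j = 0; j < row.length; j++) {
                output.put(dimensionNames.get(j), row[j]);
            }
            outputs.add(output);
        }
        return outputs;
    }
}
```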
-
convertToTensor
org.tensorflow.Tensor convertToTensor(T output, ImmutableOutputInfo<T> outputIDInfo)
Converts an Output into a Tensor representing its output.
- Parameters:
output - The output to convert.
outputIDInfo - The output info to use to identify the outputs.
- Returns:
- A Tensor representing this output.
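How a single output is encoded depends on the output type. For the multi-label case, a natural encoding (assumed here; this page does not specify the layout) is a multi-hot vector with a 1 at the id of every label present. The sketch uses a `float[]` in place of the Tensor and a label-to-id map with ids 0..n-1 in place of `ImmutableOutputInfo`; `MultiHotSketch` is hypothetical.

```java
import java.util.Map;
import java.util.Set;

// Hypothetical multi-label sketch: a set of labels becomes a multi-hot vector, with
// 1.0f at the id of every label present and 0.0f elsewhere (ids assumed 0..n-1).
final class MultiHotSketch {
    private MultiHotSketch() {}

    static float[] toMultiHot(Set<String> labels, Map<String, Integer> labelToId) {
        float[] vector = new float[labelToId.size()];
        for (String label : labels) {
            Integer id = labelToId.get(label);
            if (id == null) {
                throw new IllegalArgumentException("Unknown label: " + label);
            }
            vector[id] = 1.0f;
        }
        return vector;
    }
}
```

This encoding pairs with an independent per-label sigmoid on the output side, since several positions can be 1 at once.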
-
convertToTensor
org.tensorflow.Tensor convertToTensor(List<Example<T>> examples, ImmutableOutputInfo<T> outputIDInfo)
Converts a list of Examples into a Tensor representing all the outputs in the list. It accepts a list of Example rather than a list of Output for efficiency reasons.
- Parameters:
examples - The examples to convert.
outputIDInfo - The output info to use to identify the outputs.
- Returns:
- A Tensor representing all the supplied Outputs.
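One plausible reading of the efficiency note is that the converter can pull each output straight out of its example while filling the batch, rather than first building an intermediate list of outputs. The sketch below follows that assumption for one-hot classification targets: a `float[batch][numLabels]` stands in for the batched Tensor, and the `outputOf` function stands in for reading each Example's output. `BatchEncodeSketch` is hypothetical.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical sketch: float[batch][numLabels] stands in for the batched Tensor,
// and outputOf stands in for reading each Example's output.
final class BatchEncodeSketch {
    private BatchEncodeSketch() {}

    static <E> float[][] toBatchOneHot(List<E> examples, Function<E, String> outputOf,
                                       Map<String, Integer> labelToId) {
        float[][] batch = new float[examples.size()][labelToId.size()];
        for (int i = 0; i < examples.size(); i++) {
            // Read the output directly from the example; no List<Output> is built.
            String label = outputOf.apply(examples.get(i));
            Integer id = labelToId.get(label);
            if (id == null) {
                throw new IllegalArgumentException("Unknown label: " + label);
            }
            batch[i][id] = 1.0f;
        }
        return batch;
    }
}
```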
-
generatesProbabilities
boolean generatesProbabilities()
Does this OutputConverter generate probabilities?
- Returns:
- True if it produces a probability distribution in the Prediction.
-
getTypeWitness
default Class<T> getTypeWitness()
The type witness used when deserializing the TensorFlow model from a protobuf.
The default implementation throws UnsupportedOperationException for compatibility with implementations which don't use protobuf serialization. This implementation will be removed in the next major version of Tribuo.
- Returns:
- The output class this object produces.
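The compatibility arrangement described above is the standard Java default-method pattern: adding a method to an existing interface without breaking older implementations, at the cost of a runtime failure if those implementations are asked for it. A minimal sketch follows; `WitnessedConverter`, `StringConverter` and `LegacyConverter` are hypothetical names, not Tribuo types.

```java
// Hypothetical sketch of the compatibility pattern: a default method that throws
// until an implementation opts in by overriding it.
interface WitnessedConverter<T> {
    default Class<T> getTypeWitness() {
        throw new UnsupportedOperationException(getClass().getName() + " does not supply a type witness");
    }
}

// A converter written with protobuf serialization in mind overrides the default.
final class StringConverter implements WitnessedConverter<String> {
    @Override
    public Class<String> getTypeWitness() {
        return String.class;
    }
}

// An older converter still compiles without changes but inherits the throwing default.
final class LegacyConverter implements WitnessedConverter<String> {
}
```

Once the default is removed in a later major version, every implementation must provide the method, turning the runtime failure into a compile-time error.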
-