Class MultiLabelTransformer

java.lang.Object
org.tribuo.interop.onnx.MultiLabelTransformer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, OutputTransformer<MultiLabel>, ProtoSerializable<org.tribuo.interop.onnx.protos.OutputTransformerProto>

public class MultiLabelTransformer extends Object implements OutputTransformer<MultiLabel>
Can convert an OnnxValue into a Prediction or a MultiLabel.

Accepts a single tensor representing the scores of each label in the batch.

By default, predictions are thresholded at DEFAULT_THRESHOLD: labels whose scores are above the threshold are considered present in the output. The model output is also assumed to be probabilistic by default.
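The thresholding rule described above can be illustrated with a minimal, dependency-free sketch. It is not the Tribuo implementation: the class name ThresholdSketch, the method presentLabels, and the value 0.5 are assumptions made for illustration (this page does not state the value of DEFAULT_THRESHOLD), and plain arrays stand in for the ONNX tensor and the output info.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of Tribuo: shows "scores above the threshold are present".
final class ThresholdSketch {
    // Assumed value for illustration only; check DEFAULT_THRESHOLD for the real constant.
    static final double ASSUMED_THRESHOLD = 0.5;

    // Returns the names of the labels whose score is strictly above the threshold.
    static List<String> presentLabels(String[] labelNames, float[] scores, double threshold) {
        if (labelNames.length != scores.length) {
            throw new IllegalArgumentException("Expected one score per label");
        }
        List<String> present = new ArrayList<>();
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > threshold) {
                present.add(labelNames[i]);
            }
        }
        return present;
    }
}
```

In this sketch a score exactly equal to the threshold is treated as absent, which follows the "above this" wording of the description.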

  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
    • DEFAULT_THRESHOLD

      public static final double DEFAULT_THRESHOLD
      The default threshold for conversion into a label.
  • Constructor Details

    • MultiLabelTransformer

      public MultiLabelTransformer()
      Constructs a MultiLabelTransformer with a threshold of DEFAULT_THRESHOLD, assuming the model emits probabilities.
    • MultiLabelTransformer

      public MultiLabelTransformer(double threshold, boolean generatesProbabilities)
      Constructs a MultiLabelTransformer with the supplied threshold.
      Parameters:
      threshold - The threshold to set. Must be between 0 and 1 if generatesProbabilities is true.
      generatesProbabilities - Whether the model produces probabilistic outputs.
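The constraint on the threshold parameter can be sketched as a small validation routine. This is an illustration only: the class ThresholdCheckSketch, the inclusive treatment of the bounds 0 and 1, and the use of IllegalArgumentException are all assumptions, since this page does not say how or where Tribuo enforces the range.

```java
// Hypothetical validation helper, not part of Tribuo.
final class ThresholdCheckSketch {
    // Rejects thresholds outside [0, 1] only when the model output is probabilistic.
    // The bounds are treated as inclusive here, which is an assumption.
    static void validate(double threshold, boolean generatesProbabilities) {
        // The negated range check also rejects NaN in the probabilistic case.
        if (generatesProbabilities && !(threshold >= 0.0 && threshold <= 1.0)) {
            throw new IllegalArgumentException(
                "Threshold must be between 0 and 1 for probabilistic outputs, found " + threshold);
        }
    }
}
```

When generatesProbabilities is false, the scores are not constrained to the unit interval, so the sketch accepts any threshold in that case.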
  • Method Details

    • postConfig

      public void postConfig()
      Used by the OLCUT configuration system, and should not be called by external code.
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
    • deserializeFromProto

      public static MultiLabelTransformer deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
      Throws:
      com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
    • transformToPrediction

      public Prediction<MultiLabel> transformToPrediction(List<ai.onnxruntime.OnnxValue> value, ImmutableOutputInfo<MultiLabel> outputIDInfo, int numValidFeatures, Example<MultiLabel> example)
      Description copied from interface: OutputTransformer
      Converts an OnnxValue into a Prediction.
      Specified by:
      transformToPrediction in interface OutputTransformer<MultiLabel>
      Parameters:
      value - The value to convert.
      outputIDInfo - The output info to use to identify the outputs.
      numValidFeatures - The number of valid features used by the prediction.
      example - The example to insert into the prediction.
      Returns:
      A prediction object.
    • transformToOutput

      public MultiLabel transformToOutput(List<ai.onnxruntime.OnnxValue> value, ImmutableOutputInfo<MultiLabel> outputIDInfo)
      Description copied from interface: OutputTransformer
      Converts an OnnxValue into the specified output type.
      Specified by:
      transformToOutput in interface OutputTransformer<MultiLabel>
      Parameters:
      value - The value to convert.
      outputIDInfo - The output info to use to identify the outputs.
      Returns:
      An output.
    • transformToBatchPrediction

      public List<Prediction<MultiLabel>> transformToBatchPrediction(List<ai.onnxruntime.OnnxValue> value, ImmutableOutputInfo<MultiLabel> outputIDInfo, int[] numValidFeatures, List<Example<MultiLabel>> examples)
      Description copied from interface: OutputTransformer
      Converts an OnnxValue containing multiple outputs into a list of Predictions.
      Specified by:
      transformToBatchPrediction in interface OutputTransformer<MultiLabel>
      Parameters:
      value - The value to convert.
      outputIDInfo - The output info to use to identify the outputs.
      numValidFeatures - The number of valid features used by each prediction.
      examples - The examples to insert into the predictions.
      Returns:
      A list of predictions.
    • transformToBatchOutput

      public List<MultiLabel> transformToBatchOutput(List<ai.onnxruntime.OnnxValue> value, ImmutableOutputInfo<MultiLabel> outputIDInfo)
      Description copied from interface: OutputTransformer
      Converts an OnnxValue containing multiple outputs into a list of Outputs.
      Specified by:
      transformToBatchOutput in interface OutputTransformer<MultiLabel>
      Parameters:
      value - The value to convert.
      outputIDInfo - The output info to use to identify the outputs.
      Returns:
      A list of outputs.
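As a rough illustration of the batch conversion, the sketch below turns a matrix of scores into one label set per example. It is not the Tribuo implementation: the class BatchThresholdSketch, the [batchSize][numLabels] layout of the scores, and the use of plain strings in place of MultiLabel and ImmutableOutputInfo are assumptions made to keep the example self-contained.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helper, not part of Tribuo.
final class BatchThresholdSketch {
    // Converts scores laid out as [batchSize][numLabels] (an assumed layout)
    // into one set of present label names per example, in batch order.
    static List<Set<String>> toBatchLabels(String[] labelNames, float[][] scores, double threshold) {
        List<Set<String>> batch = new ArrayList<>(scores.length);
        for (float[] row : scores) {
            if (row.length != labelNames.length) {
                throw new IllegalArgumentException("Expected one score per label in every row");
            }
            Set<String> present = new LinkedHashSet<>();
            for (int j = 0; j < row.length; j++) {
                if (row[j] > threshold) {
                    present.add(labelNames[j]);
                }
            }
            batch.add(present);
        }
        return batch;
    }
}
```

An example with no score above the threshold yields an empty set in this sketch, so the returned list always has one entry per row of the input.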
    • generatesProbabilities

      public boolean generatesProbabilities()
      Description copied from interface: OutputTransformer
      Does this OutputTransformer generate probabilities?
      Specified by:
      generatesProbabilities in interface OutputTransformer<MultiLabel>
      Returns:
      True if it produces a probability distribution in the Prediction.
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getTypeWitness

      public Class<MultiLabel> getTypeWitness()
      Description copied from interface: OutputTransformer
      The type witness used when deserializing the ONNX model from a protobuf.

      The default implementation throws UnsupportedOperationException for compatibility with implementations which don't use protobuf serialization. The default implementation will be removed in the next major version of Tribuo.

      Specified by:
      getTypeWitness in interface OutputTransformer<MultiLabel>
      Returns:
      The output class this object produces.
    • equals

      public boolean equals(Object o)
      Overrides:
      equals in class Object
    • hashCode

      public int hashCode()
      Overrides:
      hashCode in class Object
    • serialize

      public org.tribuo.interop.onnx.protos.OutputTransformerProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.interop.onnx.protos.OutputTransformerProto>
      Returns:
      The protobuf.
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>