Class LabelTransformer

java.lang.Object
org.tribuo.interop.onnx.LabelTransformer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, OutputTransformer<Label>, ProtoSerializable<org.tribuo.interop.onnx.protos.OutputTransformerProto>
Direct Known Subclasses:
LabelOneVOneTransformer

public class LabelTransformer extends Object implements OutputTransformer<Label>
Can convert an OnnxValue into a Prediction or a Label.

Accepts:

  • a tuple (tensor, sequence(map(long, float))), as produced by skl2onnx.
  • a tuple (tensor, float tensor), as produced by the bare ONNX-ML operators (e.g., SVMClassifier).
  • a single float tensor, as produced by PyTorch.
By default it assumes the model scores are probabilities.

The scores must have exactly one entry per output dimension. If the scores are instead the outputs of a one-vs-one predictor over all pairs of classes, use LabelOneVOneTransformer, which performs the appropriate scoring operation; this class throws an exception on such input.
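The width rule above can be sketched in plain Java. This is a minimal illustration, not Tribuo's implementation: the class `ScoreWidthCheck` and its methods are hypothetical, and a score row is modelled as a `float[]`. A one-vs-one model over k classes emits one score per pair of classes, k(k-1)/2 in total (6 for 4 classes), which is how a mismatched row can often be recognised. For exactly three classes k(k-1)/2 = k, so width alone cannot tell the two layouts apart.

```java
/** Hypothetical helper illustrating the score-width rule; not part of Tribuo. */
final class ScoreWidthCheck {
    private ScoreWidthCheck() {}

    /** Number of pairwise scores a one-vs-one model emits for k classes: k(k-1)/2. */
    static int oneVsOneWidth(int numClasses) {
        return numClasses * (numClasses - 1) / 2;
    }

    /** Throws unless the row has exactly one score per output dimension. */
    static void checkWidth(float[] scores, int numClasses) {
        if (scores.length == numClasses) {
            return; // one score per class: the layout this transformer expects
        }
        String msg = "Expected " + numClasses + " scores, found " + scores.length;
        if (scores.length == oneVsOneWidth(numClasses)) {
            // Width matches the pairwise layout, so point at the one-vs-one transformer.
            msg += "; this looks like one-vs-one output, try LabelOneVOneTransformer";
        }
        throw new IllegalArgumentException(msg);
    }
}
```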

  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
    • generatesProbabilities

      @Config(description="Does this transformer produce probabilistic outputs?") protected boolean generatesProbabilities
  • Constructor Details

    • LabelTransformer

      public LabelTransformer()
      Constructs a LabelTransformer which assumes the model emits probabilities.
    • LabelTransformer

      public LabelTransformer(boolean generatesProbabilities)
      Constructs a LabelTransformer.
      Parameters:
      generatesProbabilities - Does this model emit probabilistic outputs?
  • Method Details

    • deserializeFromProto

      public static LabelTransformer deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
      Throws:
      com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
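      A minimal sketch of the kind of version guard such a deserialization factory might apply, written against hypothetical record types rather than the real protobuf classes (`SerializedTransformer`, `Transformer` and the value 0 for `CURRENT_VERSION` are all assumptions). The real method unpacks a `com.google.protobuf.Any` and can also fail with `InvalidProtocolBufferException`; that step is omitted here.

```java
/** Hypothetical sketch of a version-checked deserialization factory; not Tribuo's code. */
final class VersionedFactory {
    /** Assumed value, for illustration only. */
    static final int CURRENT_VERSION = 0;

    /** Stand-in for the serialized form: version, class name and the probability flag. */
    record SerializedTransformer(int version, String className, boolean generatesProbabilities) {}

    /** Stand-in for the deserialized transformer. */
    record Transformer(boolean generatesProbabilities) {}

    private VersionedFactory() {}

    static Transformer deserialize(SerializedTransformer data) {
        // Reject negative versions and versions written by a newer release.
        if (data.version() < 0 || data.version() > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + data.version()
                    + ", this class supports at most version " + CURRENT_VERSION);
        }
        return new Transformer(data.generatesProbabilities());
    }
}
```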
    • transformToPrediction

      public Prediction<Label> transformToPrediction(List<ai.onnxruntime.OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo, int numValidFeatures, Example<Label> example)
      Description copied from interface: OutputTransformer
      Converts an OnnxValue into a Prediction.
      Specified by:
      transformToPrediction in interface OutputTransformer<Label>
      Parameters:
      tensor - The value to convert.
      outputIDInfo - The output info to use to identify the outputs.
      numValidFeatures - The number of valid features used by the prediction.
      example - The example to insert into the prediction.
      Returns:
      A prediction object.
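      Conceptually, turning one row of scores into a prediction means pairing each score with its label, picking the highest-scoring label, and recording whether the scores are probabilities. The sketch below shows that shape with simplified, hypothetical types: `SimplePrediction` stands in for `Prediction<Label>`, and a `List<String>` in output-id order stands in for `ImmutableOutputInfo<Label>`. It is an illustration under those assumptions, not the library's code.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical, simplified view of building a prediction from one score row. */
final class PredictionSketch {
    /** Stand-in for Prediction<Label>: best label, its score, all scores, and the flags. */
    record SimplePrediction(String label, double score, Map<String, Double> scores,
                            boolean probabilistic, int numValidFeatures) {}

    private PredictionSketch() {}

    /**
     * @param row       one row of model scores, indexed by output id
     * @param idToLabel label names in output-id order (stand-in for ImmutableOutputInfo)
     */
    static SimplePrediction toPrediction(float[] row, List<String> idToLabel,
                                         boolean generatesProbabilities, int numValidFeatures) {
        if (row.length == 0 || row.length != idToLabel.size()) {
            throw new IllegalArgumentException("Expected " + idToLabel.size()
                    + " scores, found " + row.length);
        }
        Map<String, Double> scores = new LinkedHashMap<>(); // preserves output-id order
        int best = 0;
        for (int i = 0; i < row.length; i++) {
            scores.put(idToLabel.get(i), (double) row[i]);
            if (row[i] > row[best]) {
                best = i; // strict '>' keeps the lowest id on ties
            }
        }
        return new SimplePrediction(idToLabel.get(best), row[best], scores,
                generatesProbabilities, numValidFeatures);
    }
}
```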
    • transformToOutput

      public Label transformToOutput(List<ai.onnxruntime.OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo)
      Description copied from interface: OutputTransformer
      Converts an OnnxValue into the specified output type.
      Specified by:
      transformToOutput in interface OutputTransformer<Label>
      Parameters:
      tensor - The value to convert.
      outputIDInfo - The output info to use to identify the outputs.
      Returns:
      An output.
    • getBatchPredictions

      protected float[][] getBatchPredictions(List<ai.onnxruntime.OnnxValue> inputs, ImmutableOutputInfo<Label> outputIDInfo)
      Rationalises the output of an ONNX model into a standard format suitable for downstream work in Tribuo.
      Parameters:
      inputs - The ONNX model output.
      outputIDInfo - The output id mapping.
      Returns:
      A 2D array of outputs: the first dimension is the batch size and the second is the output space.
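      For the skl2onnx layout, per-example scores arrive as a sequence of maps rather than a dense tensor. The sketch below flattens such a sequence into the [batch][outputs] shape described above. It assumes the map keys are integer output ids in [0, numOutputs); the class name `ZipMapFlattener` is hypothetical, and plain Java collections stand in for the ONNX Runtime value types. Ids absent from a map are left at 0.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: flatten one (output id -> score) map per example into a dense array. */
final class ZipMapFlattener {
    private ZipMapFlattener() {}

    /** Assumes map keys are integer output ids in [0, numOutputs). */
    static float[][] flatten(List<Map<Long, Float>> perExample, int numOutputs) {
        float[][] out = new float[perExample.size()][numOutputs]; // zero-initialised
        for (int i = 0; i < perExample.size(); i++) {
            for (Map.Entry<Long, Float> e : perExample.get(i).entrySet()) {
                long id = e.getKey();
                if (id < 0 || id >= numOutputs) {
                    throw new IllegalArgumentException("Output id " + id
                            + " out of range for " + numOutputs + " outputs");
                }
                out[i][(int) id] = e.getValue(); // place the score at its output id
            }
        }
        return out;
    }
}
```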
    • transformToBatchPrediction

      public List<Prediction<Label>> transformToBatchPrediction(List<ai.onnxruntime.OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo, int[] numValidFeatures, List<Example<Label>> examples)
      Description copied from interface: OutputTransformer
      Converts an OnnxValue containing multiple outputs into a list of Predictions.
      Specified by:
      transformToBatchPrediction in interface OutputTransformer<Label>
      Parameters:
      tensor - The value to convert.
      outputIDInfo - The output info to use to identify the outputs.
      numValidFeatures - The number of valid features used by each prediction.
      examples - The examples to insert into the predictions.
      Returns:
      A list of predictions.
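      The batch variant consumes three parallel inputs: one score row, one valid-feature count and one example per prediction. A hypothetical sketch of pairing them up, assuming the rows have already been normalised into a float[][] and failing fast when the lengths disagree (the names `BatchZip` and `Row` are illustrative only):

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper pairing each score row with its feature count and example. */
final class BatchZip {
    /** One element of the batch, ready to be turned into a single prediction. */
    record Row<E>(float[] scores, int numValidFeatures, E example) {}

    private BatchZip() {}

    static <E> List<Row<E>> zip(float[][] scores, int[] numValidFeatures, List<E> examples) {
        // All three inputs describe the same batch, so their lengths must agree.
        if (scores.length != numValidFeatures.length || scores.length != examples.size()) {
            throw new IllegalArgumentException("Batch size mismatch: " + scores.length
                    + " score rows, " + numValidFeatures.length + " feature counts, "
                    + examples.size() + " examples");
        }
        List<Row<E>> rows = new ArrayList<>(scores.length);
        for (int i = 0; i < scores.length; i++) {
            rows.add(new Row<>(scores[i], numValidFeatures[i], examples.get(i)));
        }
        return rows;
    }
}
```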
    • transformToBatchOutput

      public List<Label> transformToBatchOutput(List<ai.onnxruntime.OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo)
      Description copied from interface: OutputTransformer
      Converts an OnnxValue containing multiple outputs into a list of Outputs.
      Specified by:
      transformToBatchOutput in interface OutputTransformer<Label>
      Parameters:
      tensor - The value to convert.
      outputIDInfo - The output info to use to identify the outputs.
      Returns:
      A list of outputs.
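      Converting a batch of score rows into outputs amounts to an argmax per row. This sketch assumes the rows are already a dense float[][] and that label names are listed in output-id order; `BatchArgmax` is a hypothetical name, and ties resolve to the lowest id here, which may not match the library's behaviour.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: one label per score row via argmax. */
final class BatchArgmax {
    private BatchArgmax() {}

    static List<String> toLabels(float[][] scores, List<String> idToLabel) {
        List<String> labels = new ArrayList<>(scores.length);
        for (float[] row : scores) {
            if (row.length == 0 || row.length != idToLabel.size()) {
                throw new IllegalArgumentException("Expected " + idToLabel.size()
                        + " scores per row, found " + row.length);
            }
            int best = 0;
            for (int j = 1; j < row.length; j++) {
                if (row[j] > row[best]) {
                    best = j; // strict '>' resolves ties to the lowest id
                }
            }
            labels.add(idToLabel.get(best));
        }
        return labels;
    }
}
```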
    • generatesProbabilities

      public boolean generatesProbabilities()
      Description copied from interface: OutputTransformer
      Does this OutputTransformer generate probabilities?
      Specified by:
      generatesProbabilities in interface OutputTransformer<Label>
      Returns:
      True if it produces a probability distribution in the Prediction.
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getTypeWitness

      public Class<Label> getTypeWitness()
      Description copied from interface: OutputTransformer
      The type witness used when deserializing the ONNX model from a protobuf.

      The default implementation throws UnsupportedOperationException for compatibility with implementations which don't use protobuf serialization. This default implementation will be removed in the next major version of Tribuo.

      Specified by:
      getTypeWitness in interface OutputTransformer<Label>
      Returns:
      The output class this object produces.
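      Java erases generic type arguments at runtime, so a deserialized transformer typed as `OutputTransformer<?>` cannot be checked against `OutputTransformer<Label>` by a cast alone; a type witness carries the `Class` object that makes the check possible. A minimal sketch of the pattern with hypothetical types (`Transformer`, `StringTransformer`, `checkedCast`):

```java
/** Hypothetical illustration of the type-witness pattern; names are not Tribuo's. */
final class TypeWitnessDemo {
    interface Transformer<T> {
        /** Runtime token for T, which erasure would otherwise hide. */
        Class<T> getTypeWitness();
    }

    record StringTransformer() implements Transformer<String> {
        @Override
        public Class<String> getTypeWitness() {
            return String.class;
        }
    }

    private TypeWitnessDemo() {}

    /** Recovers a typed reference from an erased one, failing fast on a mismatch. */
    @SuppressWarnings("unchecked")
    static <T> Transformer<T> checkedCast(Transformer<?> t, Class<T> expected) {
        if (!expected.equals(t.getTypeWitness())) {
            throw new IllegalStateException("Expected " + expected.getName()
                    + " but found " + t.getTypeWitness().getName());
        }
        return (Transformer<T>) t; // safe: the witness matched
    }
}
```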
    • equals

      public boolean equals(Object o)
      Overrides:
      equals in class Object
    • hashCode

      public int hashCode()
      Overrides:
      hashCode in class Object
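      The page does not say which state equals and hashCode compare. A plausible sketch, assuming the generatesProbabilities flag is the only state that matters (the class `FlagTransformer` is hypothetical):

```java
import java.util.Objects;

/** Hypothetical value class; assumes equality depends only on the probability flag. */
final class FlagTransformer {
    private final boolean generatesProbabilities;

    FlagTransformer(boolean generatesProbabilities) {
        this.generatesProbabilities = generatesProbabilities;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        // getClass() rather than instanceof: a subclass never equals a base instance.
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        return generatesProbabilities == ((FlagTransformer) o).generatesProbabilities;
    }

    @Override
    public int hashCode() {
        return Objects.hash(generatesProbabilities); // consistent with equals
    }

    @Override
    public String toString() {
        return "FlagTransformer(generatesProbabilities=" + generatesProbabilities + ")";
    }
}
```

      Comparing getClass() rather than using instanceof keeps an instance of a subclass, such as a one-vs-one variant, from comparing equal to a base-class instance that happens to share the same flag.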
    • serialize

      public org.tribuo.interop.onnx.protos.OutputTransformerProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.interop.onnx.protos.OutputTransformerProto>
      Returns:
      The protobuf.
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>