Class LabelOneVOneTransformer

java.lang.Object
org.tribuo.interop.onnx.LabelTransformer
org.tribuo.interop.onnx.LabelOneVOneTransformer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, OutputTransformer<Label>, ProtoSerializable<org.tribuo.interop.onnx.protos.OutputTransformerProto>

public final class LabelOneVOneTransformer extends LabelTransformer
Can convert an OnnxValue into a Prediction or a Label.

Accepts:

  • a tuple (tensor, float tensor) - as produced by the bare ONNX ML operations (e.g., SVMClassifier).
  • a single float tensor.
It attempts to parse the output as a vector of predictions from one-v-one classifiers, one for each class pair. This is the kind of output produced by the ONNX SVMClassifier node. However, the ONNX spec is not clear about how this output should be parsed, and ONNX Runtime produces a two-element output for binary problems, whereas a strict one-v-one classifier would produce only a single output. As a result, this class may need to be updated as ONNX Runtime or the ONNX spec itself evolves.
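To make the pairwise layout concrete, here is a small sketch that enumerates the class pairs for n classes. It assumes the libsvm-style ordering (0,1), (0,2), ..., (0,n-1), (1,2), ..., (n-2,n-1). This documentation does not state the ordering, so treat it as an assumption. The class OneVOnePairs and its pairs method are hypothetical names used only for illustration.

```java
import java.util.ArrayList;
import java.util.List;

class OneVOnePairs {
    // Hypothetical helper, not part of Tribuo. Assumed libsvm-style ordering:
    // (0,1), (0,2), ..., (0,n-1), (1,2), ..., (n-2,n-1).
    static List<int[]> pairs(int numClasses) {
        List<int[]> result = new ArrayList<>();
        for (int i = 0; i < numClasses; i++) {
            for (int j = i + 1; j < numClasses; j++) {
                result.add(new int[] {i, j});
            }
        }
        return result;
    }

    public static void main(String[] args) {
        for (int n = 2; n <= 4; n++) {
            // The number of pairs is always n * (n - 1) / 2.
            System.out.println(n + " classes -> " + pairs(n).size() + " pairwise outputs");
        }
    }
}
```

For two classes this yields a single pair, which is why the two-element binary output from ONNX Runtime does not fit the strict one-v-one layout.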

Operates on either a list containing a single tensor [batch_size,(numOutputs*(numOutputs-1))/2], or a list containing two tensors where the second one contains the one-v-one predictions as before.
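The input handling described above can be sketched as follows: pick the score tensor from a one- or two-element output list, then check its width against numOutputs*(numOutputs-1)/2. This is a hypothetical helper (OneVOneInput.selectScores is not a Tribuo method). It assumes the ONNX values have already been materialised as Java arrays, for example via OnnxTensor.getValue(). Note that a strict width check like this one would reject the two-column binary output that ONNX Runtime produces, so real code has to special-case that situation.

```java
import java.util.List;

class OneVOneInput {
    // Hypothetical helper: returns the one-v-one score matrix [batchSize][numPairs].
    // Assumes each output has already been converted to a Java array,
    // e.g. with ai.onnxruntime.OnnxTensor.getValue().
    static float[][] selectScores(List<Object> outputs, int numClasses) {
        Object candidate;
        if (outputs.size() == 1) {
            // Single float tensor holding the pairwise scores.
            candidate = outputs.get(0);
        } else if (outputs.size() == 2) {
            // Tuple case: the first element holds the predicted labels,
            // the second holds the pairwise scores.
            candidate = outputs.get(1);
        } else {
            throw new IllegalArgumentException("Expected 1 or 2 outputs, found " + outputs.size());
        }
        if (!(candidate instanceof float[][])) {
            throw new IllegalArgumentException("Expected a float[][] score tensor");
        }
        float[][] scores = (float[][]) candidate;
        int expectedWidth = (numClasses * (numClasses - 1)) / 2;
        for (float[] row : scores) {
            // Strict one-v-one layout: one score per class pair.
            if (row.length != expectedWidth) {
                throw new IllegalArgumentException(
                        "Expected " + expectedWidth + " pairwise scores, found " + row.length);
            }
        }
        return scores;
    }
}
```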

  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
  • Constructor Details

    • LabelOneVOneTransformer

      public LabelOneVOneTransformer()
      Constructs a Label transformer that operates on one-v-one output and produces scores via voting.
  • Method Details

    • postConfig

      public void postConfig()
    • deserializeFromProto

      public static LabelOneVOneTransformer deserializeFromProto(int version, String className, com.google.protobuf.Any message)
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
    • getBatchPredictions

      protected float[][] getBatchPredictions(List<ai.onnxruntime.OnnxValue> inputs, ImmutableOutputInfo<Label> outputIDInfo)
      Rationalises the output of an ONNX model into a standard format suitable for downstream work in Tribuo.

      It unfolds one-v-one predictions into a score vector using voting. This is used when the model directly outputs the result of an ONNX SVMClassifier node; models converted with skl2onnx unpack that output themselves.

      Operates on either a list containing a single tensor [batch_size,(numOutputs*(numOutputs-1))/2], or a list containing two tensors where the second one contains the one-v-one predictions as before.

      Overrides:
      getBatchPredictions in class LabelTransformer
      Parameters:
      inputs - The ONNX model output.
      outputIDInfo - The output id mapping.
      Returns:
      A 2D array of outputs: the first dimension is the batch size and the second is the output space.
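      The voting step can be sketched independently of ONNX Runtime and Tribuo types. This sketch assumes the libsvm-style pair ordering (0,1), (0,2), ..., (n-2,n-1) and that a positive score is a vote for the first class of the pair, with ties going to the second. Neither convention is confirmed by this documentation, and the real implementation may post-process the counts further. OneVOneVoting.vote is a hypothetical name.

```java
import java.util.Arrays;

class OneVOneVoting {
    // Hypothetical sketch: turns pairwise scores [batchSize][n(n-1)/2]
    // into vote counts [batchSize][n].
    // Assumes libsvm-style pair ordering and that score > 0 votes for class i,
    // otherwise (including a score of exactly 0) for class j.
    static float[][] vote(float[][] pairScores, int numClasses) {
        int numPairs = (numClasses * (numClasses - 1)) / 2;
        float[][] votes = new float[pairScores.length][numClasses];
        for (int b = 0; b < pairScores.length; b++) {
            if (pairScores[b].length != numPairs) {
                throw new IllegalArgumentException(
                        "Expected " + numPairs + " pairwise scores, found " + pairScores[b].length);
            }
            int k = 0; // index of the current pair within the row
            for (int i = 0; i < numClasses; i++) {
                for (int j = i + 1; j < numClasses; j++) {
                    if (pairScores[b][k] > 0) {
                        votes[b][i]++;
                    } else {
                        votes[b][j]++;
                    }
                    k++;
                }
            }
        }
        return votes;
    }

    public static void main(String[] args) {
        // One example with three classes: pairs (0,1), (0,2), (1,2).
        float[][] scores = {{0.7f, 1.2f, -0.3f}};
        System.out.println(Arrays.toString(vote(scores, 3)[0]));
    }
}
```

      Running main prints [2.0, 0.0, 1.0]: class 0 wins both of its pairs and class 2 beats class 1. Each row's votes sum to the number of pairs.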
    • toString

      public String toString()
      Overrides:
      toString in class LabelTransformer