Package org.tribuo.interop.onnx
Class LabelOneVOneTransformer
java.lang.Object
org.tribuo.interop.onnx.LabelTransformer
org.tribuo.interop.onnx.LabelOneVOneTransformer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, OutputTransformer<Label>, ProtoSerializable<org.tribuo.interop.onnx.protos.OutputTransformerProto>
Can convert an OnnxValue into a Prediction or a Label.
Accepts:
- a tuple (tensor, float tensor) - as produced by the bare ONNX ML operations (e.g., SVMClassifier).
- a single float tensor.
Operates on either a list containing a single tensor of shape [batch_size,(numOutputs*(numOutputs-1))/2], or a list containing two tensors where the second one contains the one-v-one predictions in that same shape.
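The second tensor dimension above is the number of unordered class pairs. The following is a minimal sketch of that arithmetic, not Tribuo code: the class OneVOnePairs and its methods are hypothetical names, and the ordering of the pairs (i before j, enumerated row by row) is an assumption about how the one-v-one columns are laid out.

```java
import java.util.Arrays;

public final class OneVOnePairs {
    /** Number of unordered class pairs, i.e. the expected width of the one-v-one tensor. */
    public static int pairCount(int numOutputs) {
        return (numOutputs * (numOutputs - 1)) / 2;
    }

    /** Enumerates the pairs (i, j) with i < j. This column ordering is an assumption. */
    public static int[][] pairs(int numOutputs) {
        int[][] result = new int[pairCount(numOutputs)][];
        int column = 0;
        for (int i = 0; i < numOutputs; i++) {
            for (int j = i + 1; j < numOutputs; j++) {
                result[column++] = new int[]{i, j};
            }
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(pairCount(4));                   // 6
        System.out.println(Arrays.deepToString(pairs(3)));  // [[0, 1], [0, 2], [1, 2]]
    }
}
```

With four labels, for example, each row of the tensor would hold six pairwise scores.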
-
Field Summary
Modifier and Type | Field | Description
static final int | CURRENT_VERSION | Protobuf serialization version.
Fields inherited from class org.tribuo.interop.onnx.LabelTransformer
generatesProbabilities
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
-
Constructor Summary
Constructor | Description
LabelOneVOneTransformer() | Constructs a Label transformer that operates on a one v one output and produces scores via voting.
-
Method Summary
Modifier and Type | Method | Description
static LabelOneVOneTransformer | deserializeFromProto(int version, String className, com.google.protobuf.Any message) | Deserialization factory.
protected float[][] | getBatchPredictions(List<ai.onnxruntime.OnnxValue> inputs, ImmutableOutputInfo<Label> outputIDInfo) | Rationalises the output of an ONNX model into a standard format suitable for downstream work in Tribuo.
void | postConfig() |
String | toString() |
Methods inherited from class org.tribuo.interop.onnx.LabelTransformer
equals, generatesProbabilities, getProvenance, getTypeWitness, hashCode, serialize, transformToBatchOutput, transformToBatchPrediction, transformToOutput, transformToPrediction
-
Field Details
-
CURRENT_VERSION
public static final int CURRENT_VERSION
Protobuf serialization version.
-
-
Constructor Details
-
LabelOneVOneTransformer
public LabelOneVOneTransformer()
Constructs a Label transformer that operates on a one v one output and produces scores via voting.
-
-
Method Details
-
postConfig
public void postConfig()
-
deserializeFromProto
public static LabelOneVOneTransformer deserializeFromProto(int version, String className, com.google.protobuf.Any message)
Deserialization factory.
- Parameters:
version - The serialized object version.
className - The class name.
message - The serialized data.
- Returns:
The deserialized object.
-
getBatchPredictions
protected float[][] getBatchPredictions(List<ai.onnxruntime.OnnxValue> inputs, ImmutableOutputInfo<Label> outputIDInfo)
Rationalises the output of an ONNX model into a standard format suitable for downstream work in Tribuo.
It unfolds one-v-one predictions into a score vector using voting. This is used if the model directly outputs the ONNX SVMClassifier node, as skl2onnx unpacks it for you.
Operates on either a list containing a single tensor of shape [batch_size,(numOutputs*(numOutputs-1))/2], or a list containing two tensors where the second one contains the one-v-one predictions in that same shape.
- Overrides:
getBatchPredictions in class LabelTransformer
- Parameters:
inputs - The ONNX model output.
outputIDInfo - The output id mapping.
- Returns:
A 2d array of outputs; the first dimension is the batch size, the second dimension is the output space.
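The unfolding by voting described for this method can be sketched in plain Java. This is an illustrative sketch under stated assumptions, not Tribuo's implementation: OneVOneVoting and unfold are hypothetical names, the pairwise columns are assumed to be ordered (0,1), (0,2), ..., (1,2), ..., and a strictly positive score is assumed to be a vote for the first class of the pair, with anything else a vote for the second.

```java
import java.util.Arrays;

public final class OneVOneVoting {
    /**
     * Turns pairwise scores of shape [batchSize][numOutputs*(numOutputs-1)/2]
     * into vote counts of shape [batchSize][numOutputs].
     */
    public static float[][] unfold(float[][] pairScores, int numOutputs) {
        int expected = (numOutputs * (numOutputs - 1)) / 2;
        float[][] votes = new float[pairScores.length][numOutputs];
        for (int b = 0; b < pairScores.length; b++) {
            if (pairScores[b].length != expected) {
                throw new IllegalArgumentException(
                    "Expected " + expected + " pairwise scores, found " + pairScores[b].length);
            }
            int column = 0;
            for (int i = 0; i < numOutputs; i++) {
                for (int j = i + 1; j < numOutputs; j++) {
                    // Assumption: a positive score means class i beat class j.
                    if (pairScores[b][column] > 0.0f) {
                        votes[b][i]++;
                    } else {
                        votes[b][j]++;
                    }
                    column++;
                }
            }
        }
        return votes;
    }

    public static void main(String[] args) {
        float[][] scores = {
            { 1.0f,  0.5f, -2.0f},
            {-1.0f, -1.0f, -1.0f}
        };
        System.out.println(Arrays.deepToString(unfold(scores, 3)));
        // [[2.0, 0.0, 1.0], [0.0, 1.0, 2.0]]
    }
}
```

Each output row sums to the number of pairs, so the counts can be compared directly across classes; the sign convention and tie handling should be checked against the model that produced the scores.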
-
toString
- Overrides:
toString in class LabelTransformer
-