Class LabelOneVOneTransformer
java.lang.Object
org.tribuo.interop.onnx.LabelTransformer
org.tribuo.interop.onnx.LabelOneVOneTransformer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, OutputTransformer<Label>, ProtoSerializable<org.tribuo.interop.onnx.protos.OutputTransformerProto>
Can convert an OnnxValue into a Prediction or a Label.
Accepts:
- a tuple (tensor, float tensor), as produced by the bare ONNX ML operations (e.g., SVMClassifier).
- a single float tensor.
Operates on either a list containing a single tensor of shape [batch_size, (numOutputs*(numOutputs-1))/2], or a list containing two tensors where the second one contains the one-v-one predictions in that same shape.
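The two accepted input layouts can be illustrated with a minimal, self-contained sketch. It assumes each OnnxValue is represented as a plain `float[batch][pairs]` array (the class and method names `OneVOneInputSketch`, `numPairs` and `selectOneVOneScores` are hypothetical), so it shows the shape logic only and is not Tribuo's implementation.

```java
import java.util.List;

/**
 * Hypothetical sketch: reduce the accepted ONNX outputs to the one-v-one score matrix.
 * Assumption: each OnnxValue is modelled as a float[batch][pairs] array.
 */
final class OneVOneInputSketch {

    /** Number of pairwise classifiers for numOutputs classes: n(n-1)/2. */
    static int numPairs(int numOutputs) {
        return (numOutputs * (numOutputs - 1)) / 2;
    }

    /**
     * Returns the only tensor of a single-element list, or the second tensor of a
     * (labels, scores) tuple, after checking it has n(n-1)/2 columns per example.
     */
    static float[][] selectOneVOneScores(List<float[][]> outputs, int numOutputs) {
        float[][] scores;
        if (outputs.size() == 1) {
            scores = outputs.get(0);
        } else if (outputs.size() == 2) {
            scores = outputs.get(1);
        } else {
            throw new IllegalArgumentException("Expected 1 or 2 outputs, found " + outputs.size());
        }
        int expected = numPairs(numOutputs);
        for (float[] row : scores) {
            if (row.length != expected) {
                throw new IllegalArgumentException(
                    "Expected " + expected + " one-v-one scores per example, found " + row.length);
            }
        }
        return scores;
    }

    public static void main(String[] args) {
        float[][] scores = {{0.5f, -1.0f, 2.0f}}; // batch of 1, 3 classes -> 3 pairs
        float[][] labels = {{0f}};                 // stand-in for the label tensor of the tuple
        System.out.println(numPairs(3));                                           // 3
        System.out.println(selectOneVOneScores(List.of(labels, scores), 3)[0][2]); // 2.0
    }
}
```

With three classes there are three pairwise scores per example, so a `[1, 3]` tensor passes the check, while the same tensor would be rejected for four classes (which need six columns).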
Field Summary
Fields
Modifier and Type | Field | Description
static final int | CURRENT_VERSION | Protobuf serialization version.
Fields inherited from class org.tribuo.interop.onnx.LabelTransformer
generatesProbabilities
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
Constructor Summary
Constructors
Constructor | Description
LabelOneVOneTransformer() | Constructs a Label transformer that operates on a one-v-one output and produces scores via voting.
Method Summary
Modifier and Type | Method | Description
static LabelOneVOneTransformer | deserializeFromProto(int version, String className, com.google.protobuf.Any message) | Deserialization factory.
protected float[][] | getBatchPredictions(List<ai.onnxruntime.OnnxValue> inputs, ImmutableOutputInfo<Label> outputIDInfo) | Rationalises the output of an ONNX model into a standard format suitable for downstream work in Tribuo.
void | postConfig() |
String | toString() |
Methods inherited from class org.tribuo.interop.onnx.LabelTransformer
equals, generatesProbabilities, getProvenance, getTypeWitness, hashCode, serialize, transformToBatchOutput, transformToBatchPrediction, transformToOutput, transformToPrediction
Field Details
CURRENT_VERSION
public static final int CURRENT_VERSION
Protobuf serialization version.
Constructor Details
LabelOneVOneTransformer
public LabelOneVOneTransformer()
Constructs a Label transformer that operates on a one-v-one output and produces scores via voting.
Method Details
postConfig
public void postConfig()
deserializeFromProto
public static LabelOneVOneTransformer deserializeFromProto(int version, String className, com.google.protobuf.Any message)
Deserialization factory.
- Parameters:
version - The serialized object version.
className - The class name.
message - The serialized data.
- Returns:
The deserialized object.
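The `version` parameter is meant to be compared against `CURRENT_VERSION` before the message is decoded. The following is a hypothetical sketch of such a guard, assuming supported versions run from 0 up to `CURRENT_VERSION`; the value 0 and the names `VersionGuardSketch`, `isSupported` and `checkVersion` are illustrative assumptions, not Tribuo's code.

```java
/**
 * Hypothetical sketch of a version guard for a protobuf deserialization factory.
 * CURRENT_VERSION = 0 is an illustrative value, not necessarily the library's constant.
 */
final class VersionGuardSketch {
    // Hypothetical stand-in for LabelOneVOneTransformer.CURRENT_VERSION.
    static final int CURRENT_VERSION = 0;

    /** True if this reader understands the given serialized version. */
    static boolean isSupported(int version) {
        return version >= 0 && version <= CURRENT_VERSION;
    }

    /** Rejects versions written by a newer (or corrupt) serializer before any decoding happens. */
    static void checkVersion(int version) {
        if (!isSupported(version)) {
            throw new IllegalArgumentException("Unknown version " + version
                + ", this class supports at most version " + CURRENT_VERSION);
        }
    }

    public static void main(String[] args) {
        checkVersion(0); // accepted: matches the current version
        try {
            checkVersion(1); // rejected: newer than this reader supports
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Failing fast on an unknown version keeps an older reader from silently misinterpreting fields added by a newer serializer.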
getBatchPredictions
protected float[][] getBatchPredictions(List<ai.onnxruntime.OnnxValue> inputs, ImmutableOutputInfo<Label> outputIDInfo)
Rationalises the output of an ONNX model into a standard format suitable for downstream work in Tribuo.
It unfolds one-v-one predictions into a score vector using voting. This is used when the model directly outputs the ONNX SVMClassifier node; skl2onnx unpacks that node's output for you.
Operates on either a list containing a single tensor of shape [batch_size, (numOutputs*(numOutputs-1))/2], or a list containing two tensors where the second one contains the one-v-one predictions in that same shape.
- Overrides:
getBatchPredictions in class LabelTransformer
- Parameters:
inputs - The ONNX model output.
outputIDInfo - The output id mapping.
- Returns:
A 2d array of outputs; the first dimension is the batch size and the second dimension is the output space.
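Voting turns the `[batch_size, n(n-1)/2]` pairwise matrix into the `[batch_size, n]` array described above. Below is a minimal sketch of one-v-one voting under stated assumptions not confirmed by this page: pairs are ordered (0,1), (0,2), ..., (0,n-1), (1,2), ... as in libsvm, a positive decision value is a vote for the first class of the pair, and column indices equal the output ids. Tribuo's actual method may order pairs, map ids via `outputIDInfo`, or break ties differently; `OneVOneVotingSketch` and `vote` are hypothetical names.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of one-v-one voting.
 * Assumptions: libsvm-style pair order (0,1), (0,2), ..., (1,2), ...;
 * score > 0 votes for the first class of the pair, otherwise the second.
 */
final class OneVOneVotingSketch {

    /** Unfolds pairwise decision values into per-class vote counts, one row per example. */
    static float[][] vote(float[][] pairScores, int numOutputs) {
        float[][] votes = new float[pairScores.length][numOutputs];
        for (int b = 0; b < pairScores.length; b++) {
            int k = 0; // position of pair (i, j) within this example's scores
            for (int i = 0; i < numOutputs; i++) {
                for (int j = i + 1; j < numOutputs; j++) {
                    if (pairScores[b][k] > 0) {
                        votes[b][i] += 1; // pair (i, j) prefers class i
                    } else {
                        votes[b][j] += 1; // pair (i, j) prefers class j
                    }
                    k++;
                }
            }
        }
        return votes;
    }

    public static void main(String[] args) {
        // 3 classes -> pairs (0,1), (0,2), (1,2)
        float[][] pairScores = {{0.5f, 1.0f, -2.0f}};
        System.out.println(Arrays.toString(vote(pairScores, 3)[0])); // [2.0, 0.0, 1.0]
    }
}
```

In the example, class 0 wins both of its pairs and class 2 beats class 1, giving two votes for class 0, none for class 1 and one for class 2. Each row's votes always sum to n(n-1)/2, so the scores are counts rather than probabilities.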
toString
- Overrides:
toString in class LabelTransformer