Class MultiLabelVotingCombiner

java.lang.Object
org.tribuo.multilabel.ensemble.MultiLabelVotingCombiner
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, EnsembleCombiner<MultiLabel>, ProtoSerializable<org.tribuo.protos.core.EnsembleCombinerProto>

public final class MultiLabelVotingCombiner extends Object implements EnsembleCombiner<MultiLabel>
A combiner which performs a weighted or unweighted vote independently across the predicted labels in each multi-label.

This uses the thresholded predictions from each ensemble member.

This class is stateless and thread safe.
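A minimal plain-Java sketch of the unweighted case makes "voting independently across the predicted labels" concrete. It does not use Tribuo types. Each ensemble member's thresholded prediction is modelled as a Set<String> of label names. The class name UnweightedLabelVoteSketch and the "more than half the members" rule are assumptions for illustration, not the library's documented behaviour.

```java
import java.util.*;

// Hypothetical sketch, not Tribuo's code: each member's thresholded prediction
// is a Set<String> of label names, and every member casts one vote per label.
final class UnweightedLabelVoteSketch {

    static Set<String> vote(List<Set<String>> memberLabels) {
        // Tally each label separately; labels never compete with each other.
        Map<String, Integer> counts = new TreeMap<>();
        for (Set<String> labels : memberLabels) {
            for (String label : labels) {
                counts.merge(label, 1, Integer::sum);
            }
        }
        Set<String> predicted = new TreeSet<>();
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            // Assumed rule: keep a label when more than half the members predicted it.
            if (2 * e.getValue() > memberLabels.size()) {
                predicted.add(e.getKey());
            }
        }
        return predicted;
    }

    public static void main(String[] args) {
        List<Set<String>> members = List.of(
                Set.of("sports", "news"),
                Set.of("sports"),
                Set.of("news", "weather"));
        System.out.println(vote(members)); // [news, sports]
    }
}
```

Because each label is tallied on its own, the combined output can be any subset of the labels the members predicted, including the empty set.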

  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
  • Constructor Details

    • MultiLabelVotingCombiner

      public MultiLabelVotingCombiner()
      Constructs a voting combiner.
  • Method Details

    • deserializeFromProto

      public static MultiLabelVotingCombiner deserializeFromProto(int version, String className, com.google.protobuf.Any message)
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
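The version parameter lets a deserialization factory reject data written in a newer format than it understands. A minimal sketch of such a guard follows. The class name VersionGuardSketch, the assumed CURRENT_VERSION value of 0, and the exception message are hypothetical, and unpacking the Any message is omitted.

```java
// Hypothetical sketch of a version guard in a protobuf deserialization
// factory. CURRENT_VERSION = 0 is an assumed value; a real factory would
// go on to unpack the com.google.protobuf.Any message.
final class VersionGuardSketch {

    static final int CURRENT_VERSION = 0;

    static void checkVersion(int version) {
        // Reject negative versions and versions newer than this class knows.
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version
                    + ", this class supports at most version " + CURRENT_VERSION);
        }
    }

    public static void main(String[] args) {
        checkVersion(CURRENT_VERSION); // accepted
        try {
            checkVersion(CURRENT_VERSION + 1); // e.g. written by a newer release
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```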
    • serialize

      public org.tribuo.protos.core.EnsembleCombinerProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.protos.core.EnsembleCombinerProto>
      Returns:
      The protobuf.
    • combine

      public Prediction<MultiLabel> combine(ImmutableOutputInfo<MultiLabel> outputInfo, List<Prediction<MultiLabel>> predictions)
      Description copied from interface: EnsembleCombiner
      Combine the predictions.
      Specified by:
      combine in interface EnsembleCombiner<MultiLabel>
      Parameters:
      outputInfo - The output domain.
      predictions - The predictions to combine.
      Returns:
      The ensemble prediction.
    • combine

      public Prediction<MultiLabel> combine(ImmutableOutputInfo<MultiLabel> outputInfo, List<Prediction<MultiLabel>> predictions, float[] weights)
      Description copied from interface: EnsembleCombiner
      Combine the supplied predictions. The number of predictions, predictions.size(), must equal weights.length.
      Specified by:
      combine in interface EnsembleCombiner<MultiLabel>
      Parameters:
      outputInfo - The output domain.
      predictions - The predictions to combine.
      weights - The weights to use for each prediction.
      Returns:
      The ensemble prediction.
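The weighted overload can be pictured as the same per-label tally with each member's vote scaled by its weight. The sketch below uses Set<String> in place of MultiLabel predictions, enforces the predictions.size() == weights.length contract, and normalizes each label's score by the total weight. The class and method names and the 0.5 cut-off are assumptions, not Tribuo's API.

```java
import java.util.*;

// Hypothetical sketch, not Tribuo's code: a weighted per-label vote over
// label sets, with scores expressed as a share of the total weight.
final class WeightedLabelVoteSketch {

    static Map<String, Double> voteShares(List<Set<String>> memberLabels, float[] weights) {
        if (memberLabels.size() != weights.length) {
            throw new IllegalArgumentException("predictions.size() = " + memberLabels.size()
                    + " but weights.length = " + weights.length);
        }
        double weightSum = 0.0;
        Map<String, Double> shares = new TreeMap<>();
        for (int i = 0; i < weights.length; i++) {
            weightSum += weights[i];
            for (String label : memberLabels.get(i)) {
                // Each label accumulates the weights of the members that predicted it.
                shares.merge(label, (double) weights[i], Double::sum);
            }
        }
        final double total = weightSum;
        shares.replaceAll((k, v) -> v / total); // normalize to a share of the total weight
        return shares;
    }

    // Assumed decision rule: keep labels holding more than half the total weight.
    static Set<String> predictedLabels(Map<String, Double> shares) {
        Set<String> out = new TreeSet<>();
        shares.forEach((k, v) -> { if (v > 0.5) out.add(k); });
        return out;
    }

    public static void main(String[] args) {
        List<Set<String>> members = List.of(Set.of("sports"), Set.of("sports"), Set.of("news"));
        // The third member outweighs the other two combined.
        Map<String, Double> shares = voteShares(members, new float[] {1f, 1f, 5f});
        System.out.println(predictedLabels(shares)); // [news]
    }
}
```

With uniform weights this reduces to the unweighted majority vote, which is one reason to express both overloads through the same tally.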
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
    • getTypeWitness

      public Class<MultiLabel> getTypeWitness()
      Description copied from interface: EnsembleCombiner
      The type witness used when deserializing the combiner from a protobuf.

      The default implementation throws UnsupportedOperationException for compatibility with implementations which don't use protobuf serialization. This default implementation will be removed in the next major version of Tribuo.

      Specified by:
      getTypeWitness in interface EnsembleCombiner<MultiLabel>
      Returns:
      The output class this object produces.
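The compatibility pattern described above can be sketched in plain Java: a default interface method throws UnsupportedOperationException unless an implementation overrides it. The interface SketchCombiner and both implementing classes are hypothetical, and String stands in for MultiLabel.

```java
// Hypothetical types illustrating a default method that throws until overridden.
interface SketchCombiner<T> {
    default Class<T> getTypeWitness() {
        // Older implementations that never serialize to protobuf inherit this.
        throw new UnsupportedOperationException("This combiner does not support protobuf serialization");
    }
}

// An implementation written before the method existed: it compiles unchanged.
final class LegacySketchCombiner implements SketchCombiner<String> { }

// An implementation that supports protobuf deserialization supplies the witness.
final class WitnessedSketchCombiner implements SketchCombiner<String> {
    @Override
    public Class<String> getTypeWitness() {
        return String.class;
    }
}

final class TypeWitnessSketch {
    public static void main(String[] args) {
        System.out.println(new WitnessedSketchCombiner().getTypeWitness().getSimpleName()); // String
        try {
            new LegacySketchCombiner().getTypeWitness();
        } catch (UnsupportedOperationException e) {
            System.out.println("legacy combiner: " + e.getMessage());
        }
    }
}
```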
    • exportCombiner

      public ONNXNode exportCombiner(ONNXNode input)
      Exports this voting combiner to ONNX.

      The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].

      Specified by:
      exportCombiner in interface EnsembleCombiner<MultiLabel>
      Parameters:
      input - The input tensor to combine.
      Returns:
      The final ONNX node representing the voting operation.
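A plain-Java reference computation shows the tensor layout this method expects. It assumes each entry holds one member's 0/1 vote for one label and that the unweighted vote is the mean along the last (ensemble member) axis. It is not the exported ONNX graph; the class name MeanVoteReference and the averaging rule are assumptions.

```java
import java.util.Arrays;

// Hypothetical reference computation over a
// [batch_size, num_outputs, num_ensemble_members] array.
final class MeanVoteReference {

    static float[][] meanOverMembers(float[][][] input) {
        float[][] out = new float[input.length][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new float[input[b].length];
            for (int o = 0; o < input[b].length; o++) {
                float sum = 0f;
                for (float vote : input[b][o]) {
                    sum += vote;
                }
                // Fraction of ensemble members voting for label o in example b.
                out[b][o] = sum / input[b][o].length;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // batch_size = 1, num_outputs = 3, num_ensemble_members = 4
        float[][][] votes = {{{1, 1, 1, 0}, {0, 1, 0, 0}, {1, 1, 0, 0}}};
        System.out.println(Arrays.deepToString(meanOverMembers(votes))); // [[0.75, 0.25, 0.5]]
    }
}
```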
    • exportCombiner

      public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight)
      Exports this voting combiner to ONNX.

      The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].

      Specified by:
      exportCombiner in interface EnsembleCombiner<MultiLabel>
      Type Parameters:
      T - The type of the weights input reference.
      Parameters:
      input - The input tensor to combine.
      weight - The combination weight node.
      Returns:
      The final ONNX node representing the voting operation.
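With one weight per ensemble member, a natural analogue scales each vote along the last axis before averaging. This plain-Java reference is an assumption about the semantics, not the ONNX operators Tribuo emits. The class name WeightedVoteReference and the normalization by the weight sum are illustrative.

```java
import java.util.Arrays;

// Hypothetical weighted reference over a
// [batch_size, num_outputs, num_ensemble_members] array.
final class WeightedVoteReference {

    static float[][] weightedMeanOverMembers(float[][][] input, float[] weights) {
        float weightSum = 0f;
        for (float w : weights) {
            weightSum += w;
        }
        float[][] out = new float[input.length][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new float[input[b].length];
            for (int o = 0; o < input[b].length; o++) {
                if (input[b][o].length != weights.length) {
                    throw new IllegalArgumentException("need one weight per ensemble member");
                }
                float sum = 0f;
                for (int m = 0; m < weights.length; m++) {
                    sum += input[b][o][m] * weights[m]; // scale each member's vote by its weight
                }
                out[b][o] = sum / weightSum;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // batch_size = 1, num_outputs = 2, num_ensemble_members = 3
        float[][][] votes = {{{1, 0, 0}, {0, 1, 1}}};
        float[] weights = {3f, 1f, 1f};
        System.out.println(Arrays.deepToString(weightedMeanOverMembers(votes, weights)));
    }
}
```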
    • equals

      public boolean equals(Object o)
      Overrides:
      equals in class Object
    • hashCode

      public int hashCode()
      Overrides:
      hashCode in class Object
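Since the class is documented as stateless, equality can reasonably depend only on the runtime class. The sketch below shows one equals/hashCode pair consistent with that. It is an assumption based on the statelessness note, not a copy of Tribuo's source, and StatelessCombinerSketch is a hypothetical name.

```java
// Hypothetical stand-in for a stateless combiner.
final class StatelessCombinerSketch {

    @Override
    public boolean equals(Object o) {
        // No fields to compare: any two instances of this exact class are interchangeable.
        return this == o || (o != null && getClass() == o.getClass());
    }

    @Override
    public int hashCode() {
        // A constant hash is consistent with equals() when all instances are equal.
        return 31;
    }

    public static void main(String[] args) {
        StatelessCombinerSketch a = new StatelessCombinerSketch();
        StatelessCombinerSketch b = new StatelessCombinerSketch();
        System.out.println(a.equals(b) && a.hashCode() == b.hashCode()); // true
    }
}
```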