Class MultiLabelVotingCombiner

java.lang.Object
org.tribuo.multilabel.ensemble.MultiLabelVotingCombiner
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, EnsembleCombiner<MultiLabel>

public final class MultiLabelVotingCombiner extends Object implements EnsembleCombiner<MultiLabel>
A combiner that performs a weighted or unweighted vote on each label independently, using the label sets predicted by the ensemble members.

This uses the thresholded predictions from each ensemble member.

This class is stateless and thread safe.
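Below is a minimal plain-Java sketch of the per-label vote described above. It assumes each member's thresholded prediction is a set of label names, and that a label is kept when it receives more than half of the total weight. The class PerLabelVote, its vote method, and the 0.5 cut-off are illustrative assumptions, not this class's documented rule. The sketch does not use Tribuo's Prediction or MultiLabel types.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical sketch of a per-label weighted vote over thresholded multi-label predictions.
// Assumption: a label is kept when its share of the total weight exceeds 0.5.
final class PerLabelVote {

    static Set<String> vote(List<Set<String>> memberLabels, float[] weights) {
        if (memberLabels.size() != weights.length) {
            throw new IllegalArgumentException("memberLabels.size() must equal weights.length");
        }
        double totalWeight = 0.0;
        Map<String, Double> labelWeight = new HashMap<>();
        for (int i = 0; i < weights.length; i++) {
            totalWeight += weights[i];
            // Each member adds its weight to every label in its thresholded prediction.
            for (String label : memberLabels.get(i)) {
                labelWeight.merge(label, (double) weights[i], Double::sum);
            }
        }
        // Each label is decided on its own, independently of the other labels.
        Set<String> result = new TreeSet<>();
        for (Map.Entry<String, Double> e : labelWeight.entrySet()) {
            if (e.getValue() / totalWeight > 0.5) {
                result.add(e.getKey());
            }
        }
        return result;
    }
}
```

Suppose the members predict {a, b}, {a} and {b, c}. With equal weights the sketch keeps {a, b}. With weights {1, 1, 3} it keeps {b, c}. Because each label is decided separately, reweighting a single member can change several labels at once.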

  • Constructor Details

    • MultiLabelVotingCombiner

      public MultiLabelVotingCombiner()
      Constructs a voting combiner.
  • Method Details

    • combine

      public Prediction<MultiLabel> combine(ImmutableOutputInfo<MultiLabel> outputInfo, List<Prediction<MultiLabel>> predictions)
      Description copied from interface: EnsembleCombiner
      Combine the predictions.
      Specified by:
      combine in interface EnsembleCombiner<MultiLabel>
      Parameters:
      outputInfo - The output domain.
      predictions - The predictions to combine.
      Returns:
      The ensemble prediction.
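The next plain-Java sketch covers the unweighted case. It assumes, though this page does not say so, that this overload behaves like the weighted vote with every member given weight 1. Labels are again plain strings, and UnweightedVote is a hypothetical name. In this sketch a label needs a strict majority of members, so a tie drops it.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical sketch: unweighted per-label majority vote
// (assumed equivalent to weighting every member 1.0f).
final class UnweightedVote {

    static Set<String> vote(List<Set<String>> memberLabels) {
        int numMembers = memberLabels.size();
        Map<String, Integer> counts = new HashMap<>();
        for (Set<String> labels : memberLabels) {
            for (String label : labels) {
                counts.merge(label, 1, Integer::sum);
            }
        }
        Set<String> result = new TreeSet<>();
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            // Keep labels predicted by a strict majority of members.
            if (2 * e.getValue() > numMembers) {
                result.add(e.getKey());
            }
        }
        return result;
    }
}
```

For example, members {a}, {a, b}, {b} and {a} yield {a}. Label a has 3 of 4 votes, while b's 2 of 4 is a tie and is dropped under this assumed rule.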
    • combine

      public Prediction<MultiLabel> combine(ImmutableOutputInfo<MultiLabel> outputInfo, List<Prediction<MultiLabel>> predictions, float[] weights)
      Description copied from interface: EnsembleCombiner
      Combine the supplied predictions using the supplied weights. The number of predictions, predictions.size(), must equal weights.length.
      Specified by:
      combine in interface EnsembleCombiner<MultiLabel>
      Parameters:
      outputInfo - The output domain.
      predictions - The predictions to combine.
      weights - The weights to use for each prediction.
      Returns:
      The ensemble prediction.
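The weighted overload requires predictions.size() to equal weights.length. The following sketch shows that check in plain Java. The class WeightCheck is hypothetical, and IllegalArgumentException is an assumed choice, because this page does not say how a mismatch is reported.

```java
import java.util.List;

// Hypothetical sketch of the documented precondition: predictions.size() == weights.length.
// The exception type is an assumption.
final class WeightCheck {

    static void checkLengths(List<?> predictions, float[] weights) {
        if (predictions.size() != weights.length) {
            throw new IllegalArgumentException(
                "predictions.size() (" + predictions.size()
                + ") must equal weights.length (" + weights.length + ")");
        }
    }
}
```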
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
    • exportCombiner

      public ONNXNode exportCombiner(ONNXNode input)
      Exports this voting combiner to ONNX.

      The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].

      Specified by:
      exportCombiner in interface EnsembleCombiner<MultiLabel>
      Parameters:
      input - The input tensor to combine.
      Returns:
      The final node proto representing the voting operation.
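To make the [batch_size, num_outputs, num_ensemble_members] layout concrete, here is a plain-Java reference computation over a float[][][] with that shape. It is not the ONNX graph that exportCombiner builds. The rules used here are illustrative assumptions: a member votes for a label when its score exceeds 0.5, and a label is kept when more than half of the members vote for it. OnnxVoteReference is a hypothetical name.

```java
// Hypothetical reference sketch, not the exported ONNX graph: unweighted vote over
// the last axis of a [batch_size][num_outputs][num_ensemble_members] array.
// Assumptions: a member votes when its score is > 0.5; a label needs a strict majority.
final class OnnxVoteReference {

    static boolean[][] vote(float[][][] input) {
        boolean[][] out = new boolean[input.length][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new boolean[input[b].length];
            for (int o = 0; o < input[b].length; o++) {
                float[] members = input[b][o];
                int votes = 0;
                for (float score : members) {
                    if (score > 0.5f) {
                        votes++;
                    }
                }
                out[b][o] = 2 * votes > members.length;
            }
        }
        return out;
    }
}
```

In this sketch the ensemble axis is reduced away, so the result has shape [batch_size, num_outputs].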
    • exportCombiner

      public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight)
      Exports this voting combiner to ONNX.

      The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].

      Specified by:
      exportCombiner in interface EnsembleCombiner<MultiLabel>
      Type Parameters:
      T - The type of the weights input reference.
      Parameters:
      input - The input tensor to combine.
      weight - The combination weight node.
      Returns:
      The final node proto representing the voting operation.
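The weighted export adds a weight input. Assuming that input holds one weight per ensemble member (length num_ensemble_members), the sketch below is a plain-Java reference for the weighted reduction over the last axis. The 0.5 thresholds and the name OnnxWeightedVoteReference are assumptions, and the sketch is not the exported graph.

```java
// Hypothetical reference sketch, not the exported ONNX graph: weighted vote over the
// last axis. Assumptions: weights has length num_ensemble_members; a member votes when
// its score is > 0.5; a label is kept when its weighted vote share is > 0.5.
final class OnnxWeightedVoteReference {

    static boolean[][] vote(float[][][] input, float[] weights) {
        float total = 0f;
        for (float w : weights) {
            total += w;
        }
        boolean[][] out = new boolean[input.length][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new boolean[input[b].length];
            for (int o = 0; o < input[b].length; o++) {
                float[] members = input[b][o];
                if (members.length != weights.length) {
                    throw new IllegalArgumentException("num_ensemble_members must equal weights.length");
                }
                float weightedVotes = 0f;
                for (int m = 0; m < members.length; m++) {
                    if (members[m] > 0.5f) {
                        weightedVotes += weights[m];
                    }
                }
                out[b][o] = weightedVotes / total > 0.5f;
            }
        }
        return out;
    }
}
```

Take member scores {0.9, 0.1, 0.2} for a single label. Weights {3, 1, 1} keep the label, because 3/5 of the weight votes for it. Equal weights drop it, because only 1/3 does.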