Class FullyWeightedVotingCombiner

java.lang.Object
org.tribuo.classification.ensemble.FullyWeightedVotingCombiner
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, EnsembleCombiner<Label>, ProtoSerializable<org.tribuo.protos.core.EnsembleCombinerProto>

public final class FullyWeightedVotingCombiner extends Object implements EnsembleCombiner<Label>
A combiner which performs a weighted or unweighted vote across the predicted labels.

This uses the full distribution of predictions from each ensemble member, unlike VotingCombiner, which uses only the most likely prediction from each ensemble member.
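The difference between the two combiners can be shown with a small self-contained sketch. It is not Tribuo code: plain double[] arrays stand in for each member's per-label scores (an assumption so the example runs without Tribuo), and the class and method names are hypothetical. Two members narrowly prefer label 0 and one is certain of label 1, so one-vote-per-member voting picks label 0 while summing the full distributions picks label 1.

```java
import java.util.Arrays;

// Hypothetical illustration: contrasts argmax voting (one vote per member)
// with full-distribution voting. Plain arrays stand in for Tribuo's types.
public class VoteComparison {

    // Index of the largest entry (the first one wins on ties).
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // VotingCombiner-style: each member casts one vote for its most likely label.
    static double[] argmaxVotes(double[][] memberScores, int numLabels) {
        double[] votes = new double[numLabels];
        for (double[] scores : memberScores) {
            votes[argmax(scores)] += 1.0;
        }
        return votes;
    }

    // Full-distribution voting: each member adds its score for every label.
    static double[] fullDistributionVotes(double[][] memberScores, int numLabels) {
        double[] votes = new double[numLabels];
        for (double[] scores : memberScores) {
            for (int label = 0; label < numLabels; label++) {
                votes[label] += scores[label];
            }
        }
        return votes;
    }

    public static void main(String[] args) {
        // Rows are ensemble members, columns are labels 0 and 1.
        double[][] members = {
            {0.625, 0.375},
            {0.625, 0.375},
            {0.0, 1.0}
        };
        double[] hard = argmaxVotes(members, 2);
        double[] soft = fullDistributionVotes(members, 2);
        // prints "argmax votes [2.0, 1.0] -> label 0"
        System.out.println("argmax votes " + Arrays.toString(hard) + " -> label " + argmax(hard));
        // prints "full votes [1.25, 1.75] -> label 1"
        System.out.println("full votes " + Arrays.toString(soft) + " -> label " + argmax(soft));
    }
}
```

Using the whole distribution lets a confident minority outweigh a lukewarm majority, which argmax voting cannot express.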

  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
  • Constructor Details

    • FullyWeightedVotingCombiner

      public FullyWeightedVotingCombiner()
      Constructs a weighted voting combiner.
  • Method Details

    • deserializeFromProto

      public static FullyWeightedVotingCombiner deserializeFromProto(int version, String className, com.google.protobuf.Any message)
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
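      The version parameter lets a factory like this reject data written by a newer serialization format. The sketch below shows one common way to guard on it; it is an assumed pattern, not this method's body, and the value given to the stand-in CURRENT_VERSION constant is hypothetical because this page does not state it.

      ```java
      // Hypothetical sketch of a version guard in a deserialization factory.
      public class VersionGuard {

          // Stand-in for FullyWeightedVotingCombiner.CURRENT_VERSION; value assumed.
          static final int CURRENT_VERSION = 0;

          // Rejects negative versions and versions newer than this class understands.
          static void checkVersion(int version) {
              if (version < 0 || version > CURRENT_VERSION) {
                  throw new IllegalArgumentException("Unknown version " + version
                          + ", this class supports at most version " + CURRENT_VERSION);
              }
          }

          public static void main(String[] args) {
              checkVersion(CURRENT_VERSION); // accepted
              try {
                  checkVersion(CURRENT_VERSION + 1);
              } catch (IllegalArgumentException e) {
                  System.out.println(e.getMessage());
              }
          }
      }
      ```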
    • serialize

      public org.tribuo.protos.core.EnsembleCombinerProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.protos.core.EnsembleCombinerProto>
      Returns:
      The protobuf.
    • combine

      public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions)
      Description copied from interface: EnsembleCombiner
      Combine the predictions.
      Specified by:
      combine in interface EnsembleCombiner<Label>
      Parameters:
      outputInfo - The output domain.
      predictions - The predictions to combine.
      Returns:
      The ensemble prediction.
    • combine

      public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights)
      Description copied from interface: EnsembleCombiner
      Combine the supplied predictions. predictions.size() must equal weights.length.
      Specified by:
      combine in interface EnsembleCombiner<Label>
      Parameters:
      outputInfo - The output domain.
      predictions - The predictions to combine.
      weights - The weights to use for each prediction.
      Returns:
      The ensemble prediction.
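      The two combine overloads can be pictured as a weighted average of each member's full score distribution, with the unweighted overload behaving like the weighted one with every weight set to 1. The sketch below is a self-contained approximation under stated assumptions: plain arrays replace Prediction<Label> and ImmutableOutputInfo<Label>, the class name is hypothetical, weights are assumed to sum to a positive value, and dividing by the weight sum is an assumed normalization (it rescales the scores but cannot change which label wins).

      ```java
      import java.util.Arrays;

      // Hypothetical sketch of weighted full-distribution voting. Rows of memberScores
      // are ensemble members, columns are label ids; this is not Tribuo's implementation.
      public class WeightedFullVote {

          // Weighted average of the members' score distributions.
          static double[] combine(double[][] memberScores, float[] weights) {
              if (memberScores.length == 0) {
                  throw new IllegalArgumentException("Need at least one prediction");
              }
              if (memberScores.length != weights.length) {
                  throw new IllegalArgumentException("predictions.size() must equal weights.length, found "
                          + memberScores.length + " and " + weights.length);
              }
              int numLabels = memberScores[0].length;
              double[] combined = new double[numLabels];
              double weightSum = 0.0;
              for (int i = 0; i < memberScores.length; i++) {
                  for (int label = 0; label < numLabels; label++) {
                      combined[label] += weights[i] * memberScores[i][label];
                  }
                  weightSum += weights[i];
              }
              if (weightSum <= 0.0) {
                  throw new IllegalArgumentException("Weights must sum to a positive value");
              }
              // Assumed normalization: keeps the result a distribution when each member's is.
              for (int label = 0; label < numLabels; label++) {
                  combined[label] /= weightSum;
              }
              return combined;
          }

          // Unweighted overload: every member gets weight 1.
          static double[] combine(double[][] memberScores) {
              float[] ones = new float[memberScores.length];
              Arrays.fill(ones, 1.0f);
              return combine(memberScores, ones);
          }

          public static void main(String[] args) {
              double[][] members = {
                  {0.625, 0.375},
                  {0.625, 0.375},
                  {0.0, 1.0}
              };
              // Equal weights: label 1 has the larger average score.
              System.out.println(Arrays.toString(combine(members)));
              // Up-weighting the first two members flips the decision to label 0.
              // prints "[0.546875, 0.453125]"
              System.out.println(Arrays.toString(combine(members, new float[] {7f, 7f, 2f})));
          }
      }
      ```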
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
    • getTypeWitness

      public Class<Label> getTypeWitness()
      Description copied from interface: EnsembleCombiner
      The type witness used when deserializing the combiner from a protobuf.

      The default implementation throws UnsupportedOperationException for compatibility with implementations which don't use protobuf serialization. This default implementation will be removed in the next major version of Tribuo.

      Specified by:
      getTypeWitness in interface EnsembleCombiner<Label>
      Returns:
      The output class this object produces.
    • exportCombiner

      public ONNXNode exportCombiner(ONNXNode input)
      Exports this voting combiner to ONNX.

      The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].

      Specified by:
      exportCombiner in interface EnsembleCombiner<Label>
      Parameters:
      input - The node to be ensembled according to this implementation.
      Returns:
      The leaf node of the voting operation.
    • exportCombiner

      public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight)
      Exports this voting combiner to ONNX.

      The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].

      Specified by:
      exportCombiner in interface EnsembleCombiner<Label>
      Type Parameters:
      T - The type of the weights input reference.
      Parameters:
      input - The node to be ensembled according to this implementation.
      weight - The node of weights for ensembling.
      Returns:
      The leaf node of the voting operation.
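      Both export methods expect the ensemble-member axis to be the last axis of the input 3-tensor. A plain-Java picture of that layout, and of one plausible reduction over it, is sketched below. It assumes the exported graph computes a (weighted) mean over axis 2, mirroring the combine methods; the real graph is built from ONNXNode operations not shown on this page, and the class name and the convention that null weights mean an unweighted vote are hypothetical.

      ```java
      import java.util.Arrays;

      // Hypothetical plain-Java model of a reduction over the last axis of a
      // [batch_size, num_outputs, num_ensemble_members] tensor. Not ONNX code.
      public class MemberAxisReduction {

          // Returns a [batch_size, num_outputs] array; weights == null means unweighted.
          static double[][] reduceMembers(double[][][] input, double[] weights) {
              double[][] out = new double[input.length][];
              for (int b = 0; b < input.length; b++) {
                  out[b] = new double[input[b].length];
                  for (int o = 0; o < input[b].length; o++) {
                      double[] members = input[b][o];
                      if (weights != null && weights.length != members.length) {
                          throw new IllegalArgumentException("Need one weight per ensemble member");
                      }
                      double acc = 0.0;
                      double weightSum = 0.0;
                      for (int m = 0; m < members.length; m++) {
                          double w = (weights == null) ? 1.0 : weights[m];
                          acc += w * members[m];
                          weightSum += w;
                      }
                      out[b][o] = acc / weightSum;
                  }
              }
              return out;
          }

          public static void main(String[] args) {
              // batch_size = 1, num_outputs = 2, num_ensemble_members = 3.
              double[][][] input = {
                  {
                      {0.625, 0.625, 0.0}, // output 0: one score per member
                      {0.375, 0.375, 1.0}  // output 1
                  }
              };
              // prints "[[0.546875, 0.453125]]"
              System.out.println(Arrays.deepToString(reduceMembers(input, new double[] {7, 7, 2})));
          }
      }
      ```

      Keeping members on the last axis means each (example, output) cell reduces over one contiguous run of member scores, producing a [batch_size, num_outputs] result.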
    • equals

      public boolean equals(Object o)
      Overrides:
      equals in class Object
    • hashCode

      public int hashCode()
      Overrides:
      hashCode in class Object