Class AveragingCombiner

java.lang.Object
org.tribuo.regression.ensemble.AveragingCombiner
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, EnsembleCombiner<Regressor>, ProtoSerializable<org.tribuo.protos.core.EnsembleCombinerProto>

public class AveragingCombiner extends Object implements EnsembleCombiner<Regressor>
A combiner which performs a weighted or unweighted average of the predicted regressors independently across the output dimensions.
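To make the per-dimension averaging concrete, here is a minimal plain-Java sketch of the unweighted case. It is an illustration, not Tribuo's implementation. Each double[] stands in for one ensemble member's predicted Regressor values, an assumption made so the example runs without Tribuo, and the class name AveragingSketch is hypothetical.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch: unweighted average of ensemble member outputs,
 * computed independently for each output dimension. Each double[] stands
 * in for one member's predicted regressor values (not Tribuo's API).
 */
final class AveragingSketch {

    /** Averages each output dimension independently across all members. */
    static double[] average(List<double[]> memberOutputs) {
        if (memberOutputs.isEmpty()) {
            throw new IllegalArgumentException("Need at least one prediction to combine.");
        }
        int dims = memberOutputs.get(0).length;
        double[] mean = new double[dims];
        for (double[] output : memberOutputs) {
            if (output.length != dims) {
                throw new IllegalArgumentException("All members must predict the same number of dimensions.");
            }
            // Accumulate each dimension separately; dimensions never mix.
            for (int d = 0; d < dims; d++) {
                mean[d] += output[d];
            }
        }
        for (int d = 0; d < dims; d++) {
            mean[d] /= memberOutputs.size();
        }
        return mean;
    }

    public static void main(String[] args) {
        // Three ensemble members, each predicting two output dimensions.
        List<double[]> outputs = List.of(
                new double[]{1.0, 10.0},
                new double[]{2.0, 20.0},
                new double[]{3.0, 30.0});
        System.out.println(Arrays.toString(average(outputs))); // prints [2.0, 20.0]
    }
}
```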
  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
  • Constructor Details

    • AveragingCombiner

      public AveragingCombiner()
      Constructs an averaging combiner.
  • Method Details

    • deserializeFromProto

      public static AveragingCombiner deserializeFromProto(int version, String className, com.google.protobuf.Any message)
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
    • serialize

      public org.tribuo.protos.core.EnsembleCombinerProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.protos.core.EnsembleCombinerProto>
      Returns:
      The protobuf.
    • combine

      public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions)
      Description copied from interface: EnsembleCombiner
      Combine the predictions.
      Specified by:
      combine in interface EnsembleCombiner<Regressor>
      Parameters:
      outputInfo - The output domain.
      predictions - The predictions to combine.
      Returns:
      The ensemble prediction.
    • combine

      public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions, float[] weights)
      Description copied from interface: EnsembleCombiner
      Combine the supplied predictions. predictions.size() must equal weights.length.
      Specified by:
      combine in interface EnsembleCombiner<Regressor>
      Parameters:
      outputInfo - The output domain.
      predictions - The predictions to combine.
      weights - The weights to use for each prediction.
      Returns:
      The ensemble prediction.
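The weighted overload's precondition, and the weighted average it computes, can be sketched as follows. This assumes the weighted average is normalized by the sum of the weights, so the weights need not sum to one. WeightedAveragingSketch and its method are hypothetical names, and plain arrays stand in for Tribuo's Prediction and Regressor types.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch of a weighted, per-dimension average.
 * Assumption: the result is normalized by the total weight.
 */
final class WeightedAveragingSketch {

    static double[] weightedAverage(List<double[]> memberOutputs, float[] weights) {
        // Mirrors the documented precondition predictions.size() == weights.length.
        if (memberOutputs.size() != weights.length) {
            throw new IllegalArgumentException("predictions.size() = " + memberOutputs.size()
                    + ", weights.length = " + weights.length);
        }
        if (memberOutputs.isEmpty()) {
            throw new IllegalArgumentException("Need at least one prediction to combine.");
        }
        int dims = memberOutputs.get(0).length;
        double[] weightedSum = new double[dims];
        double weightTotal = 0.0;
        for (int i = 0; i < weights.length; i++) {
            double[] output = memberOutputs.get(i);
            if (output.length != dims) {
                throw new IllegalArgumentException("All members must predict the same number of dimensions.");
            }
            weightTotal += weights[i];
            for (int d = 0; d < dims; d++) {
                weightedSum[d] += weights[i] * output[d];
            }
        }
        if (weightTotal == 0.0) {
            throw new IllegalArgumentException("Weights must not sum to zero.");
        }
        // Normalize by the total weight to get a weighted mean per dimension.
        for (int d = 0; d < dims; d++) {
            weightedSum[d] /= weightTotal;
        }
        return weightedSum;
    }

    public static void main(String[] args) {
        List<double[]> outputs = List.of(new double[]{1.0, 10.0}, new double[]{3.0, 30.0});
        float[] weights = {1f, 3f};
        System.out.println(Arrays.toString(weightedAverage(outputs, weights))); // prints [2.5, 25.0]
    }
}
```

With weights {1, 3}, the second member contributes three times as much as the first in every output dimension.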
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
    • getTypeWitness

      public Class<Regressor> getTypeWitness()
      Description copied from interface: EnsembleCombiner
      The type witness used when deserializing the combiner from a protobuf.

      The default implementation throws UnsupportedOperationException for compatibility with implementations which don't use protobuf serialization. This default implementation will be removed in the next major version of Tribuo.

      Specified by:
      getTypeWitness in interface EnsembleCombiner<Regressor>
      Returns:
      The output class this object produces.
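The compatibility pattern described for getTypeWitness, a default interface method that throws until an implementation overrides it, can be sketched with hypothetical types. SketchCombiner, LegacyCombiner, DoubleCombiner and TypeWitnessDemo are illustrative stand-ins, not Tribuo classes, and Double stands in for Regressor. The checkedCast helper is an assumed example of how a deserializer might use the witness to verify the generic type.

```java
/** Hypothetical stand-in for a combiner interface with a type witness. */
interface SketchCombiner<T> {
    // Default kept for implementations that don't support protobuf serialization.
    default Class<T> getTypeWitness() {
        throw new UnsupportedOperationException("This combiner does not support protobuf serialization.");
    }
}

/** Older implementation that never overrides the witness. */
final class LegacyCombiner implements SketchCombiner<Double> { }

/** Implementation that reports its output class. */
final class DoubleCombiner implements SketchCombiner<Double> {
    @Override
    public Class<Double> getTypeWitness() {
        return Double.class;
    }
}

final class TypeWitnessDemo {
    /** Uses the witness to check a combiner's output type before casting. */
    static <T> SketchCombiner<T> checkedCast(SketchCombiner<?> combiner, Class<T> expected) {
        Class<?> witness = combiner.getTypeWitness();
        if (!expected.equals(witness)) {
            throw new IllegalArgumentException("Expected a combiner of " + expected.getName()
                    + " but found " + witness.getName());
        }
        @SuppressWarnings("unchecked")
        SketchCombiner<T> cast = (SketchCombiner<T>) combiner;
        return cast;
    }

    public static void main(String[] args) {
        SketchCombiner<Double> ok = checkedCast(new DoubleCombiner(), Double.class);
        System.out.println(ok.getTypeWitness().getSimpleName()); // prints Double
        try {
            new LegacyCombiner().getTypeWitness();
        } catch (UnsupportedOperationException e) {
            System.out.println("legacy combiner: " + e.getMessage());
        }
    }
}
```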
    • exportCombiner

      public ONNXNode exportCombiner(ONNXNode input)
      Exports this averaging combiner, writing constructed nodes into the ONNXContext governing the input node and returning the leaf node of the combiner.

      The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].

      Specified by:
      exportCombiner in interface EnsembleCombiner<Regressor>
      Parameters:
      input - The node to combine.
      Returns:
      A node representing the final average operation.
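As a reference for what the exported node computes, the following plain-Java sketch averages a [batch_size, num_outputs, num_ensemble_members] array over its last axis. It assumes the export reduces over the ensemble-member axis with equal weights. It does not use Tribuo's ONNXNode or ONNXContext API, and the class name OnnxAverageReference is hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical reference computation: mean over the last axis of a
 * [batch_size, num_outputs, num_ensemble_members] tensor.
 */
final class OnnxAverageReference {

    /** input[b][o][m] -> output[b][o], the mean over ensemble members m. */
    static float[][] meanOverMembers(float[][][] input) {
        float[][] out = new float[input.length][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new float[input[b].length];
            for (int o = 0; o < input[b].length; o++) {
                float[] members = input[b][o];
                if (members.length == 0) {
                    throw new IllegalArgumentException("Need at least one ensemble member.");
                }
                double sum = 0.0;
                for (float v : members) {
                    sum += v;
                }
                out[b][o] = (float) (sum / members.length);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Batch of 2 examples, 1 output dimension, 3 ensemble members.
        float[][][] input = {{{1f, 2f, 3f}}, {{4f, 5f, 6f}}};
        System.out.println(Arrays.deepToString(meanOverMembers(input))); // prints [[2.0], [5.0]]
    }
}
```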
    • exportCombiner

      public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight)
      Exports this averaging combiner, writing constructed nodes into the ONNXContext governing the input node and returning the leaf node of the combiner.

      The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].

      Specified by:
      exportCombiner in interface EnsembleCombiner<Regressor>
      Type Parameters:
      T - The type of the weights input reference.
      Parameters:
      input - The node to combine.
      weight - The node of weights to use in combining.
      Returns:
      A node representing the final average operation.
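For the weighted overload, a comparable reference computation is sketched below. It assumes the weights form a 1-D vector of length num_ensemble_members and that the result is normalized by their sum, so output[b][o] = sum over m of w[m] * input[b][o][m], divided by the sum of w. This is an assumption about the graph's semantics, not a description of the nodes Tribuo emits, and OnnxWeightedAverageReference is a hypothetical name.

```java
import java.util.Arrays;

/**
 * Hypothetical reference computation: weighted mean over the last axis of a
 * [batch_size, num_outputs, num_ensemble_members] tensor.
 * Assumption: weights has length num_ensemble_members and is normalized by its sum.
 */
final class OnnxWeightedAverageReference {

    static float[][] weightedMeanOverMembers(float[][][] input, float[] weights) {
        double weightTotal = 0.0;
        for (float w : weights) {
            weightTotal += w;
        }
        if (weightTotal == 0.0) {
            throw new IllegalArgumentException("Weights must not sum to zero.");
        }
        float[][] out = new float[input.length][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new float[input[b].length];
            for (int o = 0; o < input[b].length; o++) {
                float[] members = input[b][o];
                if (members.length != weights.length) {
                    throw new IllegalArgumentException("Expected one weight per ensemble member.");
                }
                double sum = 0.0;
                for (int m = 0; m < members.length; m++) {
                    sum += (double) weights[m] * members[m];
                }
                out[b][o] = (float) (sum / weightTotal);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Batch of 1 example, 2 output dimensions, 2 ensemble members.
        float[][][] input = {{{1f, 3f}, {10f, 30f}}};
        float[] weights = {1f, 3f};
        System.out.println(Arrays.deepToString(weightedMeanOverMembers(input, weights))); // prints [[2.5, 25.0]]
    }
}
```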
    • equals

      public boolean equals(Object o)
      Overrides:
      equals in class Object
    • hashCode

      public int hashCode()
      Overrides:
      hashCode in class Object