Interface EnsembleCombiner<T extends Output<T>>

All Superinterfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, ProtoSerializable<org.tribuo.protos.core.EnsembleCombinerProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable
All Known Implementing Classes:
AveragingCombiner, FullyWeightedVotingCombiner, MultiLabelVotingCombiner, VotingCombiner

public interface EnsembleCombiner<T extends Output<T>> extends com.oracle.labs.mlrg.olcut.config.Configurable, ProtoSerializable<org.tribuo.protos.core.EnsembleCombinerProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable
An interface for combining predictions. Implementations should be final and immutable.
  • Method Details

    • combine

      Prediction<T> combine(ImmutableOutputInfo<T> outputInfo, List<Prediction<T>> predictions)
      Combine the predictions.
      Parameters:
      outputInfo - The output domain.
      predictions - The predictions to combine.
      Returns:
      The ensemble prediction.
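As an illustration only, the sketch below shows one simple combination rule in plain Java: average each member's score vector, then take the highest-scoring output. It does not use Tribuo types. Each prediction is assumed to be a double[] of scores indexed by output id, standing in for the id mapping that outputInfo provides, and the class name AveragingSketch is hypothetical. The class is final and stateless, in line with the guidance that implementations be final and immutable.

```java
import java.util.List;

/**
 * Hypothetical sketch (not the Tribuo API): combines ensemble member outputs
 * by averaging their score vectors. Scores are indexed by output id, an
 * assumed stand-in for the id mapping an ImmutableOutputInfo provides.
 */
public final class AveragingSketch {

    private AveragingSketch() {}

    /** Averages the per-output scores of every member prediction. */
    public static double[] combine(int numOutputs, List<double[]> predictions) {
        if (predictions.isEmpty()) {
            throw new IllegalArgumentException("Cannot combine zero predictions");
        }
        double[] combined = new double[numOutputs];
        for (double[] scores : predictions) {
            if (scores.length != numOutputs) {
                throw new IllegalArgumentException(
                    "Expected " + numOutputs + " scores, found " + scores.length);
            }
            for (int i = 0; i < numOutputs; i++) {
                combined[i] += scores[i];
            }
        }
        // Divide the summed scores by the number of members.
        for (int i = 0; i < numOutputs; i++) {
            combined[i] /= predictions.size();
        }
        return combined;
    }

    /** Returns the index of the highest combined score (first index wins ties). */
    public static int argmax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Three ensemble members, two possible outputs.
        List<double[]> members = List.of(
            new double[] {0.7, 0.3},
            new double[] {0.4, 0.6},
            new double[] {0.8, 0.2});
        double[] avg = combine(2, members);
        System.out.println(avg[0] + " " + avg[1] + " -> output " + argmax(avg));
    }
}
```

Averaging full score vectors keeps each member's confidence; a plain vote would instead count only each member's top output.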
    • combine

      Prediction<T> combine(ImmutableOutputInfo<T> outputInfo, List<Prediction<T>> predictions, float[] weights)
      Combine the supplied predictions, weighting each one by the corresponding entry in weights. predictions.size() must equal weights.length.
      Parameters:
      outputInfo - The output domain.
      predictions - The predictions to combine.
      weights - The weights to use for each prediction.
      Returns:
      The ensemble prediction.
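For the weighted overload, here is a minimal sketch of weighted majority voting, again in plain Java rather than against the Tribuo API. Each member's prediction is assumed to be reduced to a single output id, which is a hypothetical simplification. Each member adds its weight to that id's tally, and the method rejects inputs that violate the documented precondition that predictions.size() equal weights.length. Whether Tribuo's voting combiners normalize weights or break ties the same way is not stated here, and WeightedVoteSketch is an illustrative name.

```java
import java.util.List;

/**
 * Hypothetical sketch (not the Tribuo API) of a weighted majority vote.
 * Each member's prediction is simplified to the output id it voted for.
 */
public final class WeightedVoteSketch {

    private WeightedVoteSketch() {}

    /** Returns the output id with the largest total weight (first id wins ties). */
    public static int combine(int numOutputs, List<Integer> predictedIds, float[] weights) {
        // Mirrors the documented precondition: one weight per prediction.
        if (predictedIds.size() != weights.length) {
            throw new IllegalArgumentException("predictions.size() = " + predictedIds.size()
                + " but weights.length = " + weights.length);
        }
        double[] votes = new double[numOutputs];
        for (int m = 0; m < weights.length; m++) {
            votes[predictedIds.get(m)] += weights[m];
        }
        int best = 0;
        for (int i = 1; i < numOutputs; i++) {
            if (votes[i] > votes[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Two light members vote for output 0; one heavy member votes for output 1.
        int winner = combine(2, List.of(0, 0, 1), new float[] {0.2f, 0.3f, 1.0f});
        System.out.println("winner: " + winner);
    }
}
```

With equal weights, the same three votes would select output 0, so the weights are what let a single trusted member outvote the others.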
    • exportCombiner

      default ONNXNode exportCombiner(ONNXNode input)
      Exports this ensemble combiner into the ONNX context of its input.

      The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].

      For compatibility reasons, this method has a default implementation, but calling it throws an IllegalStateException. In a future version this method will not have a default implementation and ensemble combiners will be required to provide ONNX support.

      Parameters:
      input - The node to be ensembled according to this implementation.
      Returns:
      The leaf node of the graph of operations added to ensemble the input.
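To make the expected input layout concrete, the sketch below treats the 3-tensor [batch_size, num_outputs, num_ensemble_members] as a float[][][] and averages over the last axis, producing a [batch_size, num_outputs] result. This is only the arithmetic an averaging combiner might express. An ONNX graph could encode the same reduction with an operator such as ReduceMean over axis 2, but no ONNXNode API is used or implied here, and MemberAxisMean is a hypothetical name.

```java
/**
 * Hypothetical sketch: the arithmetic of averaging over the ensemble-member
 * axis of a [batch_size, num_outputs, num_ensemble_members] tensor.
 * Plain arrays stand in for ONNX tensors; no ONNX graph is built.
 */
public final class MemberAxisMean {

    private MemberAxisMean() {}

    /** Reduces [batch][output][member] to [batch][output] by taking the mean over members. */
    public static float[][] meanOverMembers(float[][][] input) {
        float[][] out = new float[input.length][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new float[input[b].length];
            for (int o = 0; o < input[b].length; o++) {
                float[] members = input[b][o];
                float sum = 0f;
                for (float v : members) {
                    sum += v;
                }
                out[b][o] = sum / members.length;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // batch_size = 1, num_outputs = 2, num_ensemble_members = 3
        float[][][] input = {{{1f, 2f, 3f}, {4f, 4f, 4f}}};
        float[][] out = meanOverMembers(input);
        System.out.println(out[0][0] + ", " + out[0][1]);
    }
}
```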
    • exportCombiner

      default <U extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, U weight)
      Exports this ensemble combiner into the ONNX context of its input.

      The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].

      For compatibility reasons, this method has a default implementation, but calling it throws an IllegalStateException. In a future version this method will not have a default implementation and ensemble combiners will be required to provide ONNX support.

      Type Parameters:
      U - The type of the weights input reference.
      Parameters:
      input - The node to be ensembled according to this implementation.
      weight - The node of weights for ensembling.
      Returns:
      The leaf node of the graph of operations added to ensemble the input.
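The weighted overload adds a weight input. The sketch below assumes one weight per ensemble member and a normalized weighted mean. Both are assumptions, since the exact formula is implementation-specific. Under those assumptions, each output is the sum over members m of input[b][o][m] * weight[m], divided by the sum of the weights. The code uses plain arrays, and MemberAxisWeightedMean is a hypothetical name.

```java
/**
 * Hypothetical sketch: a normalized weighted mean over the ensemble-member
 * axis of a [batch_size, num_outputs, num_ensemble_members] tensor.
 * Assumes one weight per member; plain arrays stand in for ONNX tensors.
 */
public final class MemberAxisWeightedMean {

    private MemberAxisWeightedMean() {}

    public static float[][] weightedMean(float[][][] input, float[] weights) {
        float weightSum = 0f;
        for (float w : weights) {
            weightSum += w;
        }
        if (weightSum == 0f) {
            throw new IllegalArgumentException("Weights must not sum to zero");
        }
        float[][] out = new float[input.length][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new float[input[b].length];
            for (int o = 0; o < input[b].length; o++) {
                float[] members = input[b][o];
                if (members.length != weights.length) {
                    throw new IllegalArgumentException("Expected " + weights.length
                        + " ensemble members, found " + members.length);
                }
                // Weighted sum over members, then normalize by the total weight.
                float acc = 0f;
                for (int m = 0; m < members.length; m++) {
                    acc += members[m] * weights[m];
                }
                out[b][o] = acc / weightSum;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // batch_size = 1, num_outputs = 2, num_ensemble_members = 2
        float[][][] input = {{{1f, 3f}, {2f, 2f}}};
        float[][] out = weightedMean(input, new float[] {1f, 3f});
        System.out.println(out[0][0] + ", " + out[0][1]);
    }
}
```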
    • getTypeWitness

      default Class<T> getTypeWitness()
      The type witness used when deserializing the combiner from a protobuf.

      The default implementation throws UnsupportedOperationException for compatibility with implementations which don't use protobuf serialization. This implementation will be removed in the next major version of Tribuo.

      Returns:
      The output class this object produces.
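The sketch below illustrates the two behaviors described here using stand-in types. It defines an interface whose default getTypeWitness() throws UnsupportedOperationException, a legacy implementation that inherits that default, and a newer implementation that returns its output Class. With that Class in hand, a deserializer could verify the type before an unchecked cast. Combiner, Label, LegacyCombiner, LabelCombiner and asCombinerOf are hypothetical names, not Tribuo classes, and the checking helper only approximates what a real protobuf deserializer might do.

```java
/**
 * Hypothetical sketch of the type-witness pattern with stand-in types
 * (not Tribuo classes).
 */
public final class TypeWitnessSketch {

    private TypeWitnessSketch() {}

    interface Combiner<T> {
        /** Default kept for older implementations; throws when called. */
        default Class<T> getTypeWitness() {
            throw new UnsupportedOperationException("This combiner does not provide a type witness");
        }
    }

    /** Stand-in output type. */
    static final class Label {}

    /** Older implementation: inherits the throwing default. */
    static final class LegacyCombiner implements Combiner<Label> {}

    /** Newer implementation: supplies its output class as the witness. */
    static final class LabelCombiner implements Combiner<Label> {
        @Override
        public Class<Label> getTypeWitness() {
            return Label.class;
        }
    }

    /** Checks the witness at runtime before narrowing the wildcard type. */
    @SuppressWarnings("unchecked")
    static <T> Combiner<T> asCombinerOf(Combiner<?> combiner, Class<T> expected) {
        if (!expected.equals(combiner.getTypeWitness())) {
            throw new IllegalStateException("Expected a combiner for " + expected.getName());
        }
        return (Combiner<T>) combiner;
    }

    public static void main(String[] args) {
        Combiner<Label> c = asCombinerOf(new LabelCombiner(), Label.class);
        System.out.println("witness: " + c.getTypeWitness().getSimpleName());
        try {
            new LegacyCombiner().getTypeWitness();
        } catch (UnsupportedOperationException e) {
            System.out.println("legacy: " + e.getMessage());
        }
    }
}
```

Because Java erases generic type arguments at runtime, an explicit Class object like this is one common way to recover the output type after deserialization.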
    • deserialize

      static EnsembleCombiner<?> deserialize(org.tribuo.protos.core.EnsembleCombinerProto proto)
      Deserialization helper for EnsembleCombiner.
      Parameters:
      proto - The proto to deserialize.
      Returns:
      The combiner.