Class WeightedEnsembleModel<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.ensemble.EnsembleModel<T>
org.tribuo.ensemble.WeightedEnsembleModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ONNXExportable, ProtoSerializable<org.tribuo.protos.core.ModelProto>

public final class WeightedEnsembleModel<T extends Output<T>> extends EnsembleModel<T> implements ONNXExportable
An ensemble model that uses weights to combine the ensemble member predictions.
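How the weights and member predictions are combined is delegated to the EnsembleCombiner. The standalone sketch below illustrates one common scheme, weighted majority voting over class labels. It is for illustration only: WeightedVoteSketch and combine are hypothetical names, labels are plain strings rather than Tribuo outputs, and Tribuo's own combiners may work differently.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration, not Tribuo's implementation.
// Each member's predicted label receives that member's weight, and the
// label with the largest accumulated weight is returned.
final class WeightedVoteSketch {
    private WeightedVoteSketch() {}

    static String combine(List<String> memberLabels, float[] weights) {
        if (memberLabels.size() != weights.length) {
            throw new IllegalArgumentException("Expected one weight per member, found "
                    + weights.length + " weights for " + memberLabels.size() + " members");
        }
        Map<String, Float> scores = new HashMap<>();
        for (int i = 0; i < weights.length; i++) {
            // Add the i-th member's weight to the label it predicted.
            scores.merge(memberLabels.get(i), weights[i], Float::sum);
        }
        String best = null;
        float bestScore = Float.NEGATIVE_INFINITY;
        // Ties are broken by HashMap iteration order, which is unspecified.
        for (Map.Entry<String, Float> e : scores.entrySet()) {
            if (e.getValue() > bestScore) {
                best = e.getKey();
                bestScore = e.getValue();
            }
        }
        return best; // null only when there are no members
    }
}
```

For example, combine(List.of("cat", "cat", "dog"), new float[] {0.2f, 0.2f, 0.5f}) returns "dog", because the single weight of 0.5 exceeds the combined 0.4 for "cat".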
  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
    • weights

      protected final float[] weights
      The ensemble member combination weights.
    • combiner

      protected final EnsembleCombiner<T> combiner
      The ensemble combination function.
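The weights and combiner fields work together: the combiner reduces the members' outputs to a single output using one weight per member. The following minimal sketch does this for scalar regression outputs, assuming a normalized weighted mean. CombinerSketch and WeightedMeanSketch are hypothetical names and do not reproduce Tribuo's EnsembleCombiner interface.

```java
import java.util.List;

// Hypothetical stand-in for a combination function (not Tribuo's API):
// reduces the members' outputs to one output using per-member weights.
interface CombinerSketch<O> {
    O combine(List<O> memberOutputs, float[] weights);
}

// Normalized weighted mean of scalar outputs: sum_i(w_i * y_i) / sum_i(w_i).
final class WeightedMeanSketch implements CombinerSketch<Double> {
    @Override
    public Double combine(List<Double> memberOutputs, float[] weights) {
        if (memberOutputs.size() != weights.length || weights.length == 0) {
            throw new IllegalArgumentException("Need at least one member and exactly one weight per member");
        }
        double weightedSum = 0.0;
        double weightTotal = 0.0;
        for (int i = 0; i < weights.length; i++) {
            weightedSum += weights[i] * memberOutputs.get(i);
            weightTotal += weights[i];
        }
        if (weightTotal == 0.0) {
            throw new IllegalArgumentException("Weights must not sum to zero");
        }
        return weightedSum / weightTotal;
    }
}
```

Dividing by the weight total means that scaling every weight by the same constant leaves the result unchanged. This is a design choice of the sketch, not something the documentation above specifies.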
  • Constructor Details

  • Method Details

    • deserializeFromProto

      public static WeightedEnsembleModel<?> deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
      Throws:
      com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
    • predict

      public Prediction<T> predict(Example<T> example)
      Description copied from class: Model
      Uses the model to predict the output for a single example.

      predict does not mutate the example.

      Throws IllegalArgumentException if the example has no features or no feature overlap with the model.

      Specified by:
      predict in class Model<T>
      Parameters:
      example - the example to predict.
      Returns:
      the result of the prediction.
    • getExcuse

      public Optional<Excuse<T>> getExcuse(Example<T> example)
      Description copied from class: Model
      Generates an excuse for an example.

      This attempts to explain a classification result. Generating an excuse may be quite an expensive operation.

      This excuse contains either per-class information or an entry with key Model.ALL_OUTPUTS.

      The optional is empty if the model does not provide excuses.

      Specified by:
      getExcuse in class EnsembleModel<T>
      Parameters:
      example - The input example.
      Returns:
      An optional excuse object. The optional is empty if this model does not provide excuses.
    • copy

      protected EnsembleModel<T> copy(String name, EnsembleModelProvenance newProvenance, List<Model<T>> newModels)
      Description copied from class: EnsembleModel
      Copies this ensemble model.
      Specified by:
      copy in class EnsembleModel<T>
      Parameters:
      name - The new name.
      newProvenance - The new provenance.
      newModels - The new models.
      Returns:
      A copy of the ensemble model.
    • createEnsembleFromExistingModels

      public static <T extends Output<T>> WeightedEnsembleModel<T> createEnsembleFromExistingModels(String name, List<Model<T>> models, EnsembleCombiner<T> combiner)
      Creates an ensemble from existing models. The model outputs are combined using uniform weights.

      Uses the feature and output domains from the first model as the ensemble model's domains. Each ensemble member continues to use its own domains.

      Throws IllegalArgumentException if the output domains don't cover the same dimensions.

      Type Parameters:
      T - The output type.
      Parameters:
      name - The ensemble name.
      models - The ensemble members.
      combiner - The combination function.
      Returns:
      A weighted ensemble model.
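The following sketch shows how a uniform-weight overload can delegate to an explicitly weighted one. EnsembleSketch is hypothetical, its members are placeholders rather than Tribuo Model instances, and the constant 1.0f is an assumption: the documentation above says only that the weights are uniform.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch, not Tribuo's code. E stands in for the member model type.
final class EnsembleSketch<E> {
    final List<E> members;
    final float[] weights;

    private EnsembleSketch(List<E> members, float[] weights) {
        this.members = List.copyOf(members);
        this.weights = weights.clone(); // defensive copy so callers can't mutate the weights
    }

    // Uniform-weight overload: every member receives the same weight.
    static <M> EnsembleSketch<M> create(List<M> members) {
        float[] weights = new float[members.size()];
        Arrays.fill(weights, 1.0f); // assumed constant; any single value yields uniform weights
        return create(members, weights);
    }

    // Weighted overload: requires exactly one weight per member.
    static <M> EnsembleSketch<M> create(List<M> members, float[] weights) {
        if (weights.length != members.size()) {
            throw new IllegalArgumentException("Expected " + members.size()
                    + " weights, found " + weights.length);
        }
        return new EnsembleSketch<>(members, weights);
    }
}
```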
    • createEnsembleFromExistingModels

      public static <T extends Output<T>> WeightedEnsembleModel<T> createEnsembleFromExistingModels(String name, List<Model<T>> models, EnsembleCombiner<T> combiner, float[] weights)
      Creates an ensemble from existing models.

      Uses the feature and output domains from the first model as the ensemble model's domains. Each ensemble member continues to use its own domains.

      Throws IllegalArgumentException if the output domains don't cover the same dimensions, or if the weights array is not the same length as the list of models.

      Type Parameters:
      T - The output type.
      Parameters:
      name - The ensemble name.
      models - The ensemble members.
      combiner - The combination function.
      weights - The model combination weights.
      Returns:
      A weighted ensemble model.
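The two validation rules above can be sketched as follows. This sketch assumes that each output domain is represented as a set of output names (for example, class labels) and that "covering the same dimensions" means set equality. DomainCheckSketch and checkDomains are hypothetical, and Tribuo's actual check may differ.

```java
import java.util.List;
import java.util.Set;

// Hypothetical sketch, not Tribuo's code: member output domains are modelled
// as sets of output names, and the ensemble adopts the first member's domain.
final class DomainCheckSketch {
    private DomainCheckSketch() {}

    static Set<String> checkDomains(List<Set<String>> memberDomains, float[] weights) {
        if (memberDomains.isEmpty()) {
            throw new IllegalArgumentException("At least one ensemble member is required");
        }
        if (weights.length != memberDomains.size()) {
            throw new IllegalArgumentException("Expected " + memberDomains.size()
                    + " weights, found " + weights.length);
        }
        Set<String> first = memberDomains.get(0);
        for (int i = 1; i < memberDomains.size(); i++) {
            // Set.equals compares contents, so element order and Set implementation don't matter.
            if (!first.equals(memberDomains.get(i))) {
                throw new IllegalArgumentException("Member " + i + " covers " + memberDomains.get(i)
                        + " but member 0 covers " + first);
            }
        }
        return first;
    }
}
```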
    • exportONNXModel

      public ai.onnx.proto.OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion)
      Exports this EnsembleModel as an ONNX model.

      Note: if the ensemble members or the ensemble combination function do not implement ONNXExportable, this method throws UnsupportedOperationException.

      Specified by:
      exportONNXModel in interface ONNXExportable
      Parameters:
      domain - A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
      modelVersion - A version number for this model.
      Returns:
      An ONNX ModelProto representing the model.
    • writeONNXGraph

      public ONNXNode writeONNXGraph(ONNXRef<?> input)
      Description copied from interface: ONNXExportable
      Writes this Model into OnnxMl.GraphProto.Builder inside the input's ONNXContext.
      Specified by:
      writeONNXGraph in interface ONNXExportable
      Parameters:
      input - The input to the model graph.
      Returns:
      the output node of the model graph.
    • serialize

      public org.tribuo.protos.core.ModelProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.protos.core.ModelProto>
      Overrides:
      serialize in class Model<T>
      Returns:
      The protobuf.