Package org.tribuo.regression.ensemble
Class AveragingCombiner
java.lang.Object
org.tribuo.regression.ensemble.AveragingCombiner
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, EnsembleCombiner<Regressor>, ProtoSerializable<org.tribuo.protos.core.EnsembleCombinerProto>
A combiner which performs a weighted or unweighted average of the predicted
regressors independently across the output dimensions.
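The per-dimension arithmetic can be illustrated without any Tribuo types. This is a minimal sketch. It assumes each ensemble member's prediction has already been extracted into a double[] with one entry per output dimension, in the same dimension order for every member. PerDimensionAverage is a hypothetical helper and is not part of Tribuo.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper (not a Tribuo class): unweighted per-dimension average of member predictions. */
final class PerDimensionAverage {
    private PerDimensionAverage() {}

    /**
     * Averages the members' predictions, one output dimension at a time.
     * Every array must have the same length (the number of output dimensions).
     */
    static double[] average(List<double[]> memberOutputs) {
        if (memberOutputs.isEmpty()) {
            throw new IllegalArgumentException("Need at least one prediction to average");
        }
        int numDims = memberOutputs.get(0).length;
        double[] mean = new double[numDims];
        for (double[] member : memberOutputs) {
            if (member.length != numDims) {
                throw new IllegalArgumentException("All predictions must have " + numDims + " dimensions");
            }
            // Accumulate each dimension independently of the others.
            for (int d = 0; d < numDims; d++) {
                mean[d] += member[d];
            }
        }
        for (int d = 0; d < numDims; d++) {
            mean[d] /= memberOutputs.size();
        }
        return mean;
    }

    public static void main(String[] args) {
        // Two ensemble members, each predicting two output dimensions.
        List<double[]> members = List.of(new double[] {1.0, 10.0}, new double[] {3.0, 20.0});
        System.out.println(Arrays.toString(average(members))); // prints [2.0, 15.0]
    }
}
```

Because each dimension is averaged on its own, a member's value for one output never affects the result for another output.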
Field Summary
Modifier and Type | Field | Description
static final int | CURRENT_VERSION | Protobuf serialization version.
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
Constructor Summary
Constructor | Description
AveragingCombiner() | Constructs an averaging combiner.
Method Summary
Modifier and Type | Method | Description
Prediction<Regressor> | combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions) | Combine the predictions.
Prediction<Regressor> | combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions, float[] weights) | Combine the supplied predictions.
static AveragingCombiner | deserializeFromProto(int version, String className, com.google.protobuf.Any message) | Deserialization factory.
boolean | equals(Object o) |
ONNXNode | exportCombiner(ONNXNode input) | Exports this averaging combiner, writing constructed nodes into the ONNXContext governing input and returning the leaf node of the combiner.
<T extends ONNXRef<?>> ONNXNode | exportCombiner(ONNXNode input, T weight) | Exports this averaging combiner, writing constructed nodes into the ONNXContext governing input and returning the leaf node of the combiner.
com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance | getProvenance() |
Class<Regressor> | getTypeWitness() | The type witness used when deserializing the combiner from a protobuf.
int | hashCode() |
org.tribuo.protos.core.EnsembleCombinerProto | serialize() | Serializes this object to a protobuf.
String | toString() |
Methods inherited from class java.lang.Object
clone, finalize, getClass, notify, notifyAll, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable
postConfig
Field Details
CURRENT_VERSION
public static final int CURRENT_VERSION
Protobuf serialization version.
Constructor Details
AveragingCombiner
public AveragingCombiner()
Constructs an averaging combiner.
Method Details
deserializeFromProto
public static AveragingCombiner deserializeFromProto(int version, String className, com.google.protobuf.Any message)
Deserialization factory.
- Parameters:
version - The serialized object version.
className - The class name.
message - The serialized data.
- Returns:
The deserialized object.
serialize
public org.tribuo.protos.core.EnsembleCombinerProto serialize()
Description copied from interface: ProtoSerializable
Serializes this object to a protobuf.
- Specified by:
serialize in interface ProtoSerializable<org.tribuo.protos.core.EnsembleCombinerProto>
- Returns:
The protobuf.
combine
public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions)
Description copied from interface: EnsembleCombiner
Combine the predictions.
- Specified by:
combine in interface EnsembleCombiner<Regressor>
- Parameters:
outputInfo - The output domain.
predictions - The predictions to combine.
- Returns:
The ensemble prediction.
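The role of outputInfo as the output domain can be sketched by keying each member's prediction on dimension name and averaging over the names the domain defines. This is a hypothetical sketch. NamedAverage, the List<String> standing in for ImmutableOutputInfo<Regressor>, and the Map<String, Double> standing in for each Prediction<Regressor> are all simplifications, and it is an assumption that Tribuo aligns dimensions this way internally.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch (not Tribuo's code): average member predictions by dimension name. */
final class NamedAverage {
    private NamedAverage() {}

    /**
     * @param domain      dimension names of the output domain (a stand-in for outputInfo)
     * @param predictions one map per ensemble member, from dimension name to predicted value
     * @return the mean value for each dimension, in domain order
     */
    static Map<String, Double> combine(List<String> domain, List<Map<String, Double>> predictions) {
        if (predictions.isEmpty()) {
            throw new IllegalArgumentException("No predictions to combine");
        }
        Map<String, Double> result = new LinkedHashMap<>();
        for (String dim : domain) {
            double sum = 0.0;
            for (Map<String, Double> p : predictions) {
                Double value = p.get(dim);
                if (value == null) {
                    throw new IllegalArgumentException("A prediction is missing dimension " + dim);
                }
                sum += value;
            }
            // Looking up by name keeps dimensions aligned even if members order them differently.
            result.put(dim, sum / predictions.size());
        }
        return result;
    }
}
```

The domain drives the loop, so every dimension it defines appears in the result, and a member that omits one is rejected rather than silently skipped.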
combine
public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions, float[] weights)
Description copied from interface: EnsembleCombiner
Combine the supplied predictions. predictions.size() must equal weights.length.
- Specified by:
combine in interface EnsembleCombiner<Regressor>
- Parameters:
outputInfo - The output domain.
predictions - The predictions to combine.
weights - The weights to use for each prediction.
- Returns:
The ensemble prediction.
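A weighted average that enforces the documented precondition might look like the following. This is a sketch under stated assumptions: WeightedAverage is hypothetical, each prediction is a plain double[] indexed by output dimension, and the weighted sum is divided by the total weight so the weights need not sum to one. This page does not say how AveragingCombiner normalizes its weights.

```java
import java.util.List;

/** Hypothetical sketch (not Tribuo's code): weighted per-dimension average. */
final class WeightedAverage {
    private WeightedAverage() {}

    static double[] combine(List<double[]> predictions, float[] weights) {
        // The documented precondition: predictions.size() must equal weights.length.
        if (predictions.size() != weights.length) {
            throw new IllegalArgumentException("predictions.size() = " + predictions.size()
                    + " but weights.length = " + weights.length);
        }
        if (predictions.isEmpty()) {
            throw new IllegalArgumentException("No predictions to combine");
        }
        int numDims = predictions.get(0).length;
        double[] weightedSum = new double[numDims];
        double weightTotal = 0.0;
        for (int i = 0; i < weights.length; i++) {
            double[] p = predictions.get(i);
            if (p.length != numDims) {
                throw new IllegalArgumentException("All predictions must have " + numDims + " dimensions");
            }
            for (int d = 0; d < numDims; d++) {
                weightedSum[d] += weights[i] * p[d];
            }
            weightTotal += weights[i];
        }
        if (weightTotal == 0.0) {
            throw new IllegalArgumentException("Weights must not sum to zero");
        }
        // Assumed normalization: divide by the total weight so the weights need not sum to 1.
        for (int d = 0; d < numDims; d++) {
            weightedSum[d] /= weightTotal;
        }
        return weightedSum;
    }
}
```

For example, member outputs 1.0 and 3.0 with weights 1 and 3 give (1 * 1.0 + 3 * 3.0) / 4 = 2.5.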
toString
public String toString()
- Overrides:
toString in class Object
getProvenance
public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
getTypeWitness
public Class<Regressor> getTypeWitness()
Description copied from interface: EnsembleCombiner
The type witness used when deserializing the combiner from a protobuf. The default implementation throws UnsupportedOperationException for compatibility with implementations which don't use protobuf serialization. This default implementation will be removed in the next major version of Tribuo.
- Specified by:
getTypeWitness in interface EnsembleCombiner<Regressor>
- Returns:
The output class this object produces.
exportCombiner
public ONNXNode exportCombiner(ONNXNode input)
Exports this averaging combiner, writing constructed nodes into the ONNXContext governing input and returning the leaf node of the combiner.
The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
- Specified by:
exportCombiner in interface EnsembleCombiner<Regressor>
- Parameters:
input - The node to combine.
- Returns:
A node representing the final average operation.
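For reference, averaging a [batch_size, num_outputs, num_ensemble_members] tensor can be computed in plain Java as a mean over the last axis, giving a [batch_size, num_outputs] result. This sketch assumes the exported graph reduces over the ensemble-member axis (an ONNX ReduceMean over axis 2 expresses the same reduction, up to a kept singleton dimension). The page does not name the operators actually emitted, and EnsembleAxisMean is a hypothetical name.

```java
/** Hypothetical sketch: what averaging a [batch_size, num_outputs, num_ensemble_members] tensor computes. */
final class EnsembleAxisMean {
    private EnsembleAxisMean() {}

    /** Returns a [batch_size][num_outputs] array holding the mean over the last (ensemble-member) axis. */
    static double[][] reduceMeanOverMembers(double[][][] input) {
        double[][] out = new double[input.length][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new double[input[b].length];
            for (int o = 0; o < input[b].length; o++) {
                // All ensemble members' values for this batch row and output dimension.
                double[] members = input[b][o];
                double sum = 0.0;
                for (double m : members) {
                    sum += m;
                }
                out[b][o] = sum / members.length;
            }
        }
        return out;
    }
}
```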
exportCombiner
public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight)
Exports this averaging combiner, writing constructed nodes into the ONNXContext governing input and returning the leaf node of the combiner.
The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
- Specified by:
exportCombiner in interface EnsembleCombiner<Regressor>
- Type Parameters:
T - The type of the weights input reference.
- Parameters:
input - The node to combine.
weight - The node of weights to use in combining.
- Returns:
A node representing the final average operation.
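With a weights input, one plausible reference semantics is to scale each member's value by its weight, sum over the member axis, and divide by the sum of the weights. This is an assumption for illustration. The page says only that weight is the node of weights to use in combining. WeightedEnsembleAxisMean and its double[] weights (one per ensemble member) are hypothetical.

```java
/** Hypothetical sketch: an assumed weighted-mean semantics for the weighted export. */
final class WeightedEnsembleAxisMean {
    private WeightedEnsembleAxisMean() {}

    /**
     * @param input   tensor shaped [batch_size][num_outputs][num_ensemble_members]
     * @param weights one weight per ensemble member, shared across batch rows and outputs
     * @return a [batch_size][num_outputs] array of weighted means
     */
    static double[][] weightedMeanOverMembers(double[][][] input, double[] weights) {
        double weightSum = 0.0;
        for (double w : weights) {
            weightSum += w;
        }
        if (weightSum == 0.0) {
            throw new IllegalArgumentException("Weights must not sum to zero");
        }
        double[][] out = new double[input.length][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new double[input[b].length];
            for (int o = 0; o < input[b].length; o++) {
                double[] members = input[b][o];
                if (members.length != weights.length) {
                    throw new IllegalArgumentException("Expected one weight per ensemble member");
                }
                // Multiply by the weights, then sum over the member axis.
                double acc = 0.0;
                for (int m = 0; m < members.length; m++) {
                    acc += members[m] * weights[m];
                }
                // Divide by the total weight to obtain the weighted mean.
                out[b][o] = acc / weightSum;
            }
        }
        return out;
    }
}
```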
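equals
public boolean equals(Object o)
- Overrides:
equals in class Object
hashCode
public int hashCode()
- Overrides:
hashCode in class Object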