Class BinaryCrossEntropy
java.lang.Object
org.tribuo.multilabel.sgd.objectives.BinaryCrossEntropy
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, SGDObjective<SGDVector>, MultiLabelObjective
A multilabel version of binary cross-entropy loss which expects logits.
Generates a probabilistic model and uses a SigmoidNormalizer.
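Because the objective expects logits rather than probabilities, the per-label loss can be evaluated in a numerically stable form without first squashing the scores through a sigmoid. The following is a minimal plain-Java sketch of that idea, not Tribuo's implementation. It works on a single logit and a 0/1 target rather than an SGDVector, and the names StableBceDemo, bceWithLogits and naiveBce are hypothetical.

```java
// Sketch only (not Tribuo's implementation): binary cross entropy for one label,
// computed from a logit z and a 0/1 target y.
public final class StableBceDemo {

    /**
     * Numerically stable form: max(z, 0) - z*y + log(1 + exp(-|z|)).
     * Algebraically equal to -[y*log(sigmoid(z)) + (1-y)*log(1 - sigmoid(z))].
     */
    static double bceWithLogits(double z, double y) {
        return Math.max(z, 0.0) - z * y + Math.log1p(Math.exp(-Math.abs(z)));
    }

    /** Naive form: squash to a probability first, then take logs. Can return Infinity for large |z|. */
    static double naiveBce(double z, double y) {
        double p = 1.0 / (1.0 + Math.exp(-z));
        return -(y * Math.log(p) + (1.0 - y) * Math.log(1.0 - p));
    }

    public static void main(String[] args) {
        double[] logits = {2.0, -1.0, 40.0};
        double[] targets = {1.0, 0.0, 0.0};
        for (int i = 0; i < logits.length; i++) {
            // For z = 40, y = 0 the naive version yields Infinity because 1 - sigmoid(40) rounds to 0.
            System.out.printf("label %d: stable=%.6f naive=%.6f%n",
                    i, bceWithLogits(logits[i], targets[i]), naiveBce(logits[i], targets[i]));
        }
    }
}
```

Working directly on logits is what lets the stable form avoid ever computing log(0).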
-
Constructor Summary
- BinaryCrossEntropy(): Constructs a BinaryCrossEntropy objective.
-
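Method Summary
- VectorNormalizer getNormalizer(): Generates a new VectorNormalizer which normalizes the predictions into a suitable format.
- com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
- boolean isProbabilistic(): Returns true.
- com.oracle.labs.mlrg.olcut.util.Pair<Double,SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction): Returns a Pair of Double and SGDVector representing the loss and per label gradients respectively.
- double threshold(): The default prediction threshold for creating the output.
- String toString()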
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable
postConfig
-
Constructor Details
-
BinaryCrossEntropy
public BinaryCrossEntropy()
Constructs a BinaryCrossEntropy objective.
-
-
Method Details
-
lossAndGradient
public com.oracle.labs.mlrg.olcut.util.Pair<Double,SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction)
Returns a Pair of Double and SGDVector representing the loss and per label gradients respectively. The prediction vector is transformed to produce the per label gradient, and that vector is returned.
- Specified by:
lossAndGradient in interface SGDObjective<SGDVector>
- Parameters:
truth - The true label id
prediction - The prediction for each label id
- Returns:
- A Pair of the score and per label gradient.
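For a sigmoid output paired with binary cross-entropy, the derivative of the per-label loss with respect to the logit z_i simplifies to sigmoid(z_i) - y_i. The sketch below computes a summed loss and that per-label gradient over plain double arrays. It is an illustration under stated assumptions, not Tribuo's code: the LossAndGradient record and the array-based lossAndGradient signature are hypothetical stand-ins for Pair<Double,SGDVector>, labels are assumed to be encoded as positive for present and non-positive for absent, and Tribuo's actual sign and scaling conventions for the returned loss and gradient may differ from the textbook ones used here.

```java
import java.util.Arrays;

// Sketch only: summed BCE over logits plus d(loss)/d(z_i) = sigmoid(z_i) - y_i.
// LossAndGradient and this lossAndGradient signature are hypothetical, not Tribuo's API.
public final class BceLossAndGradient {

    /** Hypothetical holder mirroring the (loss, per-label gradient) pair. */
    record LossAndGradient(double loss, double[] gradient) {}

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static LossAndGradient lossAndGradient(double[] truth, double[] logits) {
        if (truth.length != logits.length) {
            throw new IllegalArgumentException("truth and logits must have the same length");
        }
        double loss = 0.0;
        double[] grad = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            double z = logits[i];
            // Assumption: any positive entry marks the label as present.
            double y = truth[i] > 0 ? 1.0 : 0.0;
            // Numerically stable per-label BCE computed from the logit.
            loss += Math.max(z, 0.0) - z * y + Math.log1p(Math.exp(-Math.abs(z)));
            // Gradient of the per-label loss with respect to its logit.
            grad[i] = sigmoid(z) - y;
        }
        return new LossAndGradient(loss, grad);
    }

    public static void main(String[] args) {
        LossAndGradient r = lossAndGradient(new double[]{1, 0, 1}, new double[]{0.0, 0.0, 0.0});
        System.out.println(r.loss() + " " + Arrays.toString(r.gradient()));
    }
}
```

With all logits at zero, every sigmoid output is 0.5, so each gradient entry is -0.5 for a present label and 0.5 for an absent one.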
-
getNormalizer
public VectorNormalizer getNormalizer()
Description copied from interface: MultiLabelObjective
Generates a new VectorNormalizer which normalizes the predictions into a suitable format.
- Specified by:
getNormalizer in interface MultiLabelObjective
- Returns:
- The vector normalizer for this objective.
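The class description says this objective uses a SigmoidNormalizer. A minimal sketch of what element-wise sigmoid normalization does conceptually, written over a plain double array; SigmoidNormalizeDemo and normalize are hypothetical names, not Tribuo's VectorNormalizer API, and the sketch returns a new array rather than modifying its input.

```java
import java.util.Arrays;

// Sketch only: maps each logit independently into (0, 1) with the logistic sigmoid.
// Not Tribuo's VectorNormalizer; names here are hypothetical.
public final class SigmoidNormalizeDemo {

    /** Returns a new array with the logistic sigmoid applied element-wise. */
    static double[] normalize(double[] logits) {
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = 1.0 / (1.0 + Math.exp(-logits[i]));
        }
        return out;
    }

    public static void main(String[] args) {
        double[] probs = normalize(new double[]{0.0, 3.0, -3.0});
        System.out.println(Arrays.toString(probs));
        // The outputs are per-label probabilities, so their sum need not be 1.
        System.out.println("sum = " + Arrays.stream(probs).sum());
    }
}
```

Unlike softmax, the normalized outputs are not constrained to sum to 1 (here they sum to about 1.5), which is what allows several labels to be predicted at once.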
-
isProbabilistic
public boolean isProbabilistic()
Returns true.
- Specified by:
isProbabilistic in interface MultiLabelObjective
- Returns:
- True.
-
threshold
public double threshold()
Description copied from interface: MultiLabelObjective
The default prediction threshold for creating the output.
- Specified by:
threshold in interface MultiLabelObjective
- Returns:
- The threshold.
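Once predictions are normalized, the threshold decides which labels appear in the output. A sketch, assuming a label is predicted when its normalized score is at least the threshold; the 0.5 value, the inclusive comparison, the label names, and the helper predictLabels are all illustrative assumptions rather than behaviour documented on this page.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch only: converts per-label scores into a predicted label set.
// Assumptions: scores are already sigmoid-normalized, and a label is kept when
// its score is >= threshold (the inclusive comparison is an assumption).
public final class ThresholdDemo {

    /** Returns the names of all labels whose score meets or exceeds the threshold. */
    static List<String> predictLabels(String[] labelNames, double[] scores, double threshold) {
        if (labelNames.length != scores.length) {
            throw new IllegalArgumentException("labelNames and scores must have the same length");
        }
        List<String> predicted = new ArrayList<>();
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] >= threshold) {
                predicted.add(labelNames[i]);
            }
        }
        return predicted;
    }

    public static void main(String[] args) {
        String[] labels = {"sports", "politics", "tech"}; // hypothetical label names
        double[] scores = {0.91, 0.12, 0.67};
        double threshold = 0.5; // illustrative value, not read from threshold()
        System.out.println(predictLabels(labels, scores, threshold)); // prints [sports, tech]
    }
}
```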
-
toString
public String toString()
- Overrides:
toString in class Object
-
getProvenance
public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
-