Class BinaryCrossEntropy
java.lang.Object
org.tribuo.multilabel.sgd.objectives.BinaryCrossEntropy
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, SGDObjective<SGDVector>, MultiLabelObjective
A multilabel version of binary cross entropy loss which expects logits.
Generates a probabilistic model, and uses a SigmoidNormalizer.
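As a rough illustration of what sigmoid normalization of logits means, the sketch below applies the logistic function to each label's logit independently, so every label gets its own probability and the outputs need not sum to 1. The class `SigmoidSketch` and its methods are hypothetical and operate on plain `double[]` arrays; this is not the Tribuo `SigmoidNormalizer` implementation.

```java
import java.util.Arrays;

// Hypothetical illustration only; not part of the Tribuo API.
class SigmoidSketch {
    // Logistic sigmoid, branching on the sign of z so Math.exp never overflows.
    static double sigmoid(double z) {
        if (z >= 0) {
            return 1.0 / (1.0 + Math.exp(-z));
        }
        double e = Math.exp(z);
        return e / (1.0 + e);
    }

    // Maps each logit to a probability independently of the other labels.
    static double[] normalize(double[] logits) {
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = sigmoid(logits[i]);
        }
        return probs;
    }

    public static void main(String[] args) {
        double[] logits = {2.0, 0.0, -2.0};
        System.out.println(Arrays.toString(normalize(logits)));
    }
}
```

Because each label is normalized on its own, several labels can receive a high probability for the same example, which is the behaviour a multilabel model needs.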
-
Constructor Summary
Constructors
Constructor | Description
BinaryCrossEntropy() | Constructs a BinaryCrossEntropy objective.
-
Method Summary
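Modifier and Type | Method | Description
VectorNormalizer | getNormalizer() | Generates a new VectorNormalizer which normalizes the predictions into a suitable format.
com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance | getProvenance() |
boolean | isProbabilistic() | Returns true.
com.oracle.labs.mlrg.olcut.util.Pair<Double, SGDVector> | lossAndGradient(SGDVector truth, SGDVector prediction) | Returns a Pair of Double and SGDVector representing the loss and per label gradients respectively.
double | threshold() | The default prediction threshold for creating the output.
String | toString() |
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable
postConfig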
-
Constructor Details
-
BinaryCrossEntropy
public BinaryCrossEntropy()
Constructs a BinaryCrossEntropy objective.
-
-
Method Details
-
lossAndGradient
public com.oracle.labs.mlrg.olcut.util.Pair<Double, SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction)
Returns a Pair of Double and SGDVector representing the loss and per label gradients respectively.
The prediction vector is transformed to produce the per label gradient and returned.
- Specified by:
lossAndGradient in interface SGDObjective<SGDVector>
- Parameters:
truth - The true label id
prediction - The prediction for each label id
- Returns:
- A Pair of the score and per label gradient.
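To make the loss and per label gradient concrete, here is a minimal sketch of binary cross entropy computed directly from logits on plain `double[]` arrays. The class `BceSketch`, its method signature, and the `gradOut` output parameter are hypothetical. The sketch uses the textbook gradient of the loss with respect to each logit, `sigmoid(z) - y`; the sign and scaling conventions Tribuo uses for the returned gradient are not stated on this page and may differ.

```java
// Hypothetical illustration only; not the Tribuo implementation.
class BceSketch {
    // Logistic sigmoid, branching on the sign of z so Math.exp never overflows.
    static double sigmoid(double z) {
        if (z >= 0) {
            return 1.0 / (1.0 + Math.exp(-z));
        }
        double e = Math.exp(z);
        return e / (1.0 + e);
    }

    // Returns the summed loss over all labels and writes the per label
    // gradient (with respect to each logit) into gradOut.
    static double lossAndGradient(double[] truth, double[] logits, double[] gradOut) {
        if (truth.length != logits.length || gradOut.length != logits.length) {
            throw new IllegalArgumentException("All arrays must have the same length");
        }
        double loss = 0.0;
        for (int i = 0; i < logits.length; i++) {
            double z = logits[i];
            double y = truth[i];
            // Numerically stable form of -[y*log(s(z)) + (1-y)*log(1-s(z))].
            loss += Math.max(z, 0.0) - z * y + Math.log1p(Math.exp(-Math.abs(z)));
            // Assumed convention: derivative of the loss with respect to the logit.
            gradOut[i] = sigmoid(z) - y;
        }
        return loss;
    }

    public static void main(String[] args) {
        double[] truth = {1.0, 0.0};
        double[] logits = {0.0, 0.0};
        double[] grad = new double[2];
        double loss = lossAndGradient(truth, logits, grad);
        System.out.println("loss = " + loss + ", gradient = " + java.util.Arrays.toString(grad));
    }
}
```

Working from logits rather than probabilities lets the loss avoid evaluating `log(0)` when a prediction saturates, which is one reason an objective would expect logits as input.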
-
getNormalizer
public VectorNormalizer getNormalizer()
Description copied from interface: MultiLabelObjective
Generates a new VectorNormalizer which normalizes the predictions into a suitable format.
- Specified by:
getNormalizer in interface MultiLabelObjective
- Returns:
- The vector normalizer for this objective.
-
isProbabilistic
public boolean isProbabilistic()
Returns true.
- Specified by:
isProbabilistic in interface MultiLabelObjective
- Returns:
- True.
-
threshold
public double threshold()
Description copied from interface: MultiLabelObjective
The default prediction threshold for creating the output.
- Specified by:
threshold in interface MultiLabelObjective
- Returns:
- The threshold.
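The following sketch shows how a prediction threshold can turn per label probabilities into a set of predicted labels. The class `ThresholdSketch`, the strict `>` comparison, and the value 0.5 used in `main` are all assumptions made for illustration; this page does not state the value `threshold()` returns or how Tribuo compares against it.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration only; not part of the Tribuo API.
class ThresholdSketch {
    // Returns the indices of labels whose probability exceeds the threshold.
    // Using a strict comparison is an assumption of this sketch.
    static List<Integer> selectLabels(double[] probabilities, double threshold) {
        List<Integer> selected = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            if (probabilities[i] > threshold) {
                selected.add(i);
            }
        }
        return selected;
    }

    public static void main(String[] args) {
        double[] probabilities = {0.9, 0.5, 0.1};
        // 0.5 is an assumed threshold chosen for the example.
        System.out.println(selectLabels(probabilities, 0.5));
    }
}
```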
-
toString
public String toString()
- Overrides:
toString in class Object
-
getProvenance
public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
-