public final class BinaryCrossEntropy extends Object implements MultiLabelObjective
Generates a probabilistic model, and uses a SigmoidNormalizer.
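The class description says this objective produces a probabilistic model and normalizes its outputs with a SigmoidNormalizer. As a minimal sketch of what element-wise sigmoid normalization does, the Java below maps each label's raw score to an independent probability. The class SigmoidSketch and its methods are hypothetical illustrations, not Tribuo's SigmoidNormalizer API.

```java
import java.util.Arrays;

// Hypothetical illustration of element-wise sigmoid normalization; not Tribuo's API.
final class SigmoidSketch {

    // Logistic sigmoid: maps a real-valued score to a probability between 0 and 1.
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Returns a new array with the sigmoid applied to each label's score independently.
    static double[] normalize(double[] scores) {
        double[] probabilities = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            probabilities[i] = sigmoid(scores[i]);
        }
        return probabilities;
    }

    public static void main(String[] args) {
        double[] scores = {2.0, 0.0, -1.5};
        // Each entry becomes an independent probability; a score of 0.0 maps to exactly 0.5.
        System.out.println(Arrays.toString(normalize(scores)));
    }
}
```

Because each label is squashed independently, the outputs need not sum to 1, which is what lets a multi-label model assign a high probability to several labels at once.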
Constructor | Description
---|---
BinaryCrossEntropy() | Constructs a BinaryCrossEntropy objective.
Modifier and Type | Method | Description
---|---|---
VectorNormalizer | getNormalizer() | Generates a new VectorNormalizer which normalizes the predictions into a suitable format.
com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance | getProvenance() |
boolean | isProbabilistic() | Returns true.
com.oracle.labs.mlrg.olcut.util.Pair<Double,SGDVector> | lossAndGradient(SGDVector truth, SGDVector prediction) |
double | threshold() | The default prediction threshold for creating the output.
String | toString() |
public BinaryCrossEntropy()
Constructs a BinaryCrossEntropy objective.
public com.oracle.labs.mlrg.olcut.util.Pair<Double,SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction)
The prediction vector is transformed to produce the per label gradient and returned.
Specified by: lossAndGradient in interface SGDObjective<SGDVector>
Parameters:
truth - The true label id.
prediction - The prediction for each label id.
Returns: A Pair of Double and SGDVector representing the loss and per label gradients respectively.
public VectorNormalizer getNormalizer()
Description copied from interface: MultiLabelObjective
Generates a new VectorNormalizer which normalizes the predictions into a suitable format.
Specified by: getNormalizer in interface MultiLabelObjective
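The lossAndGradient method described above returns a loss together with a per label gradient. The sketch below shows one standard way to compute multi-label binary cross-entropy, under stated assumptions: truth[i] is 1.0 when label i is present and 0.0 otherwise, prediction[i] is the raw pre-sigmoid score (logit) for label i, and the loss is summed over labels. Plain double[] arrays stand in for SGDVector, and BceSketch and LossAndGradient are hypothetical names; this is not Tribuo's implementation. For a logit z and target y, the per-label loss -[y log σ(z) + (1 - y) log(1 - σ(z))] is rewritten as max(z, 0) - zy + log(1 + e^(-|z|)), which stays finite for large |z|, and its derivative with respect to z is σ(z) - y.

```java
import java.util.Arrays;

// Hypothetical sketch of multi-label binary cross-entropy on logits; not Tribuo's code.
final class BceSketch {

    // Holds the summed loss and the per-label gradient, analogous to a Pair<Double, vector>.
    static final class LossAndGradient {
        final double loss;
        final double[] gradient;

        LossAndGradient(double loss, double[] gradient) {
            this.loss = loss;
            this.gradient = gradient;
        }
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Assumes truth[i] is 0.0 or 1.0 and prediction[i] is the logit for label i.
    static LossAndGradient lossAndGradient(double[] truth, double[] prediction) {
        if (truth.length != prediction.length) {
            throw new IllegalArgumentException("truth and prediction must have the same length");
        }
        double loss = 0.0;
        double[] gradient = new double[prediction.length];
        for (int i = 0; i < prediction.length; i++) {
            double z = prediction[i];
            double y = truth[i];
            // Numerically stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))].
            loss += Math.max(z, 0.0) - z * y + Math.log1p(Math.exp(-Math.abs(z)));
            // Derivative of this label's loss with respect to its logit z.
            gradient[i] = sigmoid(z) - y;
        }
        return new LossAndGradient(loss, gradient);
    }

    public static void main(String[] args) {
        double[] truth = {1.0, 0.0, 1.0};
        double[] logits = {2.0, -1.0, -0.5};
        LossAndGradient result = lossAndGradient(truth, logits);
        System.out.println("loss = " + result.loss);
        System.out.println("gradient = " + Arrays.toString(result.gradient));
    }
}
```

The sketch returns the gradient of the loss itself. Some SGD frameworks instead return its negation so the optimiser can add it to the weights, so check the sign convention SGDObjective expects before reusing this.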
public boolean isProbabilistic()
Returns true.
Specified by: isProbabilistic in interface MultiLabelObjective
public double threshold()
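Description copied from interface: MultiLabelObjective
The default prediction threshold for creating the output.
Specified by: threshold in interface MultiLabelObjective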
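threshold() supplies the cut-off used when turning per-label probabilities into a predicted label set. A minimal sketch of that decoding step follows, assuming a threshold of 0.5 and an inclusive comparison (neither is stated in this documentation); ThresholdSketch and decode are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of threshold-based multi-label decoding; not Tribuo's code.
final class ThresholdSketch {

    // Assumed default; the actual value returned by threshold() is not given here.
    static final double DEFAULT_THRESHOLD = 0.5;

    // Returns the indices of labels whose probability meets the threshold.
    static List<Integer> decode(double[] probabilities, double threshold) {
        List<Integer> predicted = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            // Assumption: the comparison is inclusive (>=).
            if (probabilities[i] >= threshold) {
                predicted.add(i);
            }
        }
        return predicted;
    }

    public static void main(String[] args) {
        double[] probabilities = {0.88, 0.27, 0.5, 0.12};
        System.out.println(decode(probabilities, DEFAULT_THRESHOLD));
    }
}
```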
public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.