Class BinaryCrossEntropy

java.lang.Object
org.tribuo.multilabel.sgd.objectives.BinaryCrossEntropy
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, SGDObjective<SGDVector>, MultiLabelObjective

public final class BinaryCrossEntropy extends Object implements MultiLabelObjective
A multi-label version of the binary cross-entropy loss, which expects logits (unnormalized per-label scores) as its predictions.

Generates a probabilistic model and uses a SigmoidNormalizer to convert each label's logit into a probability.
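For reference, the textbook binary cross-entropy on logits for K labels is shown here, with logit vector z and 0/1 target vector y. It is followed by the numerically stable rewrite and the derivative with respect to each logit. Summing (rather than averaging) over labels is an assumption. This page does not say which one Tribuo does.

```latex
\sigma(z) = \frac{1}{1 + e^{-z}}

\ell(\mathbf{y}, \mathbf{z})
  = -\sum_{k=1}^{K} \Big[ y_k \log \sigma(z_k) + (1 - y_k) \log\big(1 - \sigma(z_k)\big) \Big]
  = \sum_{k=1}^{K} \Big[ \max(z_k, 0) - z_k y_k + \log\big(1 + e^{-\lvert z_k \rvert}\big) \Big]

\frac{\partial \ell}{\partial z_k} = \sigma(z_k) - y_k
```

The stable form never exponentiates a positive number, so it cannot overflow for large-magnitude logits.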

  • Constructor Details

    • BinaryCrossEntropy

      public BinaryCrossEntropy()
      Constructs a BinaryCrossEntropy objective.
  • Method Details

    • lossAndGradient

      public com.oracle.labs.mlrg.olcut.util.Pair<Double,SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction)
      Returns a Pair of Double and SGDVector representing the loss and the per-label gradients, respectively.

      The prediction vector is transformed into the per-label gradient, which is then returned.

      Specified by:
      lossAndGradient in interface SGDObjective<SGDVector>
      Parameters:
      truth - The true labels, as a vector indexed by label id
      prediction - The prediction for each label id
      Returns:
      A Pair of the loss and the per-label gradient.
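The following is a minimal standalone sketch of the computation that lossAndGradient describes. It assumes the standard binary cross-entropy on logits, summed over labels, and returns the derivative sigmoid(z) - y for each label. BceSketch and LossAndGradient are hypothetical names, and double[] stands in for SGDVector. This page does not state the sign convention that Tribuo's optimizers expect for the returned gradient, so that convention is also an assumption.

```java
import java.util.Arrays;

// Hypothetical standalone sketch, not Tribuo's implementation: double[] stands in for SGDVector.
final class BceSketch {

    /** Hypothetical holder standing in for Pair<Double, SGDVector>. */
    static final class LossAndGradient {
        final double loss;
        final double[] gradient;

        LossAndGradient(double loss, double[] gradient) {
            this.loss = loss;
            this.gradient = gradient;
        }
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * @param truth  0/1 target for each label id
     * @param logits raw, unnormalized prediction for each label id
     */
    static LossAndGradient lossAndGradient(double[] truth, double[] logits) {
        if (truth.length != logits.length) {
            throw new IllegalArgumentException("truth and logits must have the same length");
        }
        double loss = 0.0; // assumption: summed over labels, not averaged
        double[] gradient = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            double z = logits[i];
            double y = truth[i];
            // Stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]; exp(-|z|) <= 1 cannot overflow
            loss += Math.max(z, 0.0) - z * y + Math.log1p(Math.exp(-Math.abs(z)));
            // Derivative of the loss with respect to logit z_i
            gradient[i] = sigmoid(z) - y;
        }
        return new LossAndGradient(loss, gradient);
    }

    public static void main(String[] args) {
        LossAndGradient result = lossAndGradient(new double[] {1.0, 0.0}, new double[] {0.0, 0.0});
        System.out.println(result.loss);                      // approximately 2 ln 2 = 1.386...
        System.out.println(Arrays.toString(result.gradient)); // [-0.5, 0.5]
    }
}
```

With both logits at 0, each label's predicted probability is 0.5. Each label therefore contributes ln 2 to the loss, and its gradient is pushed toward its own target.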
    • getNormalizer

      public VectorNormalizer getNormalizer()
      Description copied from interface: MultiLabelObjective
      Generates a new VectorNormalizer which normalizes the predictions into a suitable format.
      Specified by:
      getNormalizer in interface MultiLabelObjective
      Returns:
      The vector normalizer for this objective.
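A SigmoidNormalizer presumably applies the logistic function to each label's score independently. The following is a standalone sketch under that assumption. It uses a hypothetical SigmoidSketch.normalize method on double[] rather than Tribuo's VectorNormalizer interface.

```java
import java.util.Arrays;

// Hypothetical sketch of element-wise sigmoid normalization; not Tribuo's VectorNormalizer API.
final class SigmoidSketch {

    /** Maps each logit to an independent probability in (0, 1); returns a new array. */
    static double[] normalize(double[] logits) {
        double[] probabilities = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probabilities[i] = 1.0 / (1.0 + Math.exp(-logits[i]));
        }
        return probabilities;
    }

    public static void main(String[] args) {
        double[] probabilities = normalize(new double[] {0.0, 2.0, -2.0});
        System.out.println(Arrays.toString(probabilities));
    }
}
```

Each label is squashed independently, so the outputs need not sum to 1. This differs from a softmax over mutually exclusive classes, and it is what a multi-label objective needs.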
    • isProbabilistic

      public boolean isProbabilistic()
      Returns true.
      Specified by:
      isProbabilistic in interface MultiLabelObjective
      Returns:
      True.
    • threshold

      public double threshold()
      Description copied from interface: MultiLabelObjective
      The default prediction threshold for creating the output.
      Specified by:
      threshold in interface MultiLabelObjective
      Returns:
      The threshold.
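This page does not state the value that BinaryCrossEntropy returns from threshold(). The following minimal sketch shows how a default threshold could turn normalized scores into a predicted label set. It assumes a threshold of 0.5 on sigmoid outputs, and it assumes that labels scoring at or above the threshold are predicted. ThresholdSketch and predictLabels are hypothetical names, not Tribuo APIs.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: turns per-label probabilities into a predicted label set.
final class ThresholdSketch {

    /**
     * Returns the ids of labels whose normalized score is at or above the threshold.
     * Using >= rather than > is an assumption.
     */
    static List<Integer> predictLabels(double[] probabilities, double threshold) {
        List<Integer> predicted = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            if (probabilities[i] >= threshold) {
                predicted.add(i);
            }
        }
        return predicted;
    }

    public static void main(String[] args) {
        double assumedThreshold = 0.5; // assumed default for sigmoid-normalized scores
        double[] probabilities = {0.91, 0.12, 0.5, 0.67};
        System.out.println(predictLabels(probabilities, assumedThreshold)); // [0, 2, 3]
    }
}
```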
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>