Class LogMulticlass

java.lang.Object
org.tribuo.classification.sgd.objectives.LogMulticlass
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, LabelObjective, SGDObjective<Integer>

public class LogMulticlass extends Object implements LabelObjective
A multiclass version of the log loss.

Generates a probabilistic model and uses an ExpNormalizer.
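For reference, a multiclass log loss paired with exponential normalization is conventionally defined as below. The symbols are introduced here for illustration only: z is the vector of raw per-label scores (the prediction), y is the true label id, and K is the number of labels. The middle expansion shows where the gradient comes from.

```latex
p_k = \frac{\exp(z_k)}{\sum_{j=1}^{K} \exp(z_j)}, \qquad
\mathcal{L}(z, y) = -\log p_y = -z_y + \log \sum_{j=1}^{K} \exp(z_j), \qquad
\frac{\partial \mathcal{L}}{\partial z_k} = p_k - \mathbb{1}[k = y].
```

This page does not say whether LogMulticlass reports this loss itself or its negation (the log-likelihood log p_y, whose gradient is 1[k = y] - p_k); the two conventions differ only in sign.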

  • Constructor Details

    • LogMulticlass

      public LogMulticlass()
      Constructs a multiclass log loss.
  • Method Details

    • valueAndGradient

      @Deprecated public com.oracle.labs.mlrg.olcut.util.Pair<Double,SGDVector> valueAndGradient(int truth, SGDVector prediction)
      Deprecated.
      Description copied from interface: LabelObjective
      Scores a prediction, returning the loss and a vector of per-label gradients.
      Specified by:
      valueAndGradient in interface LabelObjective
      Parameters:
      truth - The true label id.
      prediction - The prediction for each label id.
      Returns:
      The score and per-label gradient.
    • lossAndGradient

      public com.oracle.labs.mlrg.olcut.util.Pair<Double,SGDVector> lossAndGradient(Integer truth, SGDVector prediction)
      Returns a Pair of Double and SGDVector representing the loss and the per-label gradients, respectively.

      The prediction vector is transformed in place into the per-label gradient, and that same vector is returned.

      Specified by:
      lossAndGradient in interface LabelObjective
      Specified by:
      lossAndGradient in interface SGDObjective<Integer>
      Parameters:
      truth - The true label id.
      prediction - The prediction for each label id.
      Returns:
      A Pair of the score and per-label gradient.
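A minimal sketch of the computation, assuming the standard multiclass log loss (the negative log probability of the true label after exponential normalization) over a plain double[] of raw label scores. LogLossSketch and LossAndGradient are hypothetical names, not Tribuo's API, which works on SGDVector and returns an OLCUT Pair; the sign convention Tribuo itself uses for the returned values is not stated on this page. As the method description says, the prediction array is transformed into the gradient and returned.

```java
import java.util.Arrays;

// Hypothetical sketch: not Tribuo's API (Tribuo works on SGDVector and returns an OLCUT Pair).
final class LogLossSketch {

    /** Simple holder for the loss value and the per-label gradient. */
    static final class LossAndGradient {
        final double loss;
        final double[] gradient;

        LossAndGradient(double loss, double[] gradient) {
            this.loss = loss;
            this.gradient = gradient;
        }
    }

    /**
     * Computes -log(softmax(prediction)[truth]) and overwrites prediction with the
     * gradient softmax(prediction) - onehot(truth), which is then returned.
     * Assumes prediction is non-empty and 0 <= truth < prediction.length.
     */
    static LossAndGradient lossAndGradient(int truth, double[] prediction) {
        // Exponentially normalize in place; subtracting the max keeps Math.exp from overflowing.
        double max = Arrays.stream(prediction).max().getAsDouble();
        double sum = 0.0;
        for (int i = 0; i < prediction.length; i++) {
            prediction[i] = Math.exp(prediction[i] - max);
            sum += prediction[i];
        }
        for (int i = 0; i < prediction.length; i++) {
            prediction[i] /= sum;
        }
        // Loss: negative log probability assigned to the true label.
        double loss = -Math.log(prediction[truth]);
        // Gradient w.r.t. the raw scores: p_k - 1[k == truth].
        prediction[truth] -= 1.0;
        return new LossAndGradient(loss, prediction);
    }

    public static void main(String[] args) {
        double[] scores = {2.0, 1.0, 0.1};
        LossAndGradient result = lossAndGradient(0, scores);
        System.out.println("loss = " + result.loss);
        System.out.println("gradient = " + Arrays.toString(result.gradient));
    }
}
```

Overwriting the input avoids allocating a second vector for every training example; a caller that still needs the raw scores afterwards must copy them before the call.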
    • getNormalizer

      public VectorNormalizer getNormalizer()
      Description copied from interface: LabelObjective
      Generates a new VectorNormalizer which normalizes the predictions into [0,1].
      Specified by:
      getNormalizer in interface LabelObjective
      Returns:
      The vector normalizer for this objective.
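For this objective the normalizer is an ExpNormalizer, which, as its name suggests, performs exponential normalization: each score is exponentiated and divided by the sum of all exponentiated scores, so every output lies in [0,1] and the outputs sum to 1. The following is a minimal sketch of that idea on a plain double[]; ExpNormalizerSketch is a hypothetical name, not Tribuo's ExpNormalizer or VectorNormalizer API. Subtracting the maximum score before exponentiating is an assumption made here for numerical stability; it leaves the result mathematically unchanged.

```java
import java.util.Arrays;

// Hypothetical stand-in for an exponential normalizer; not Tribuo's ExpNormalizer API.
final class ExpNormalizerSketch {

    /**
     * Returns a new array with exp(x_i - max) / sum_j exp(x_j - max) for each score x_i.
     * Every entry lies in [0,1] and the entries sum to 1 (up to rounding).
     * Assumes scores is non-empty.
     */
    static double[] normalize(double[] scores) {
        double max = Arrays.stream(scores).max().getAsDouble();
        double[] out = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            // Shifting by the max leaves the ratios unchanged but keeps Math.exp finite.
            out[i] = Math.exp(scores[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        // Without the shift, Math.exp(1000.0) would overflow to Infinity.
        double[] probs = normalize(new double[] {1000.0, 999.0, 998.0});
        System.out.println(Arrays.toString(probs));
    }
}
```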
    • isProbabilistic

      public boolean isProbabilistic()
      Returns true.
      Specified by:
      isProbabilistic in interface LabelObjective
      Returns:
      True.
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>