Class MultiLabelConverter

java.lang.Object
org.tribuo.interop.tensorflow.MultiLabelConverter
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, Serializable, OutputConverter<MultiLabel>

public class MultiLabelConverter extends Object implements OutputConverter<MultiLabel>
Converts a MultiLabel into a Tensor containing a binary encoding of the label vector, and converts a TFloat16 or TFloat32 Tensor into a Prediction or a MultiLabel.
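A minimal plain-Java sketch of the binary (multi-hot) encoding follows. It assumes each label has an integer id in 0..n-1. The vector holds 1 at the id of every label present in the MultiLabel and 0 everywhere else. BinaryLabelEncodingSketch, its encode method, and the Map<String,Integer> that stands in for the id lookup of ImmutableOutputInfo are hypothetical illustrations, not Tribuo or TensorFlow API. The real method produces a Tensor rather than a float[].

```java
import java.util.Map;
import java.util.Set;

/** Hypothetical helper illustrating a binary (multi-hot) label encoding. Not Tribuo API. */
final class BinaryLabelEncodingSketch {
    private BinaryLabelEncodingSketch() {}

    /**
     * Encodes a set of label names as a multi-hot vector.
     *
     * @param labels    the labels present in one MultiLabel
     * @param labelToId stand-in for the id lookup of ImmutableOutputInfo; ids assumed to be 0..n-1
     * @return a vector with 1.0f at each present label's id and 0.0f elsewhere
     */
    static float[] encode(Set<String> labels, Map<String, Integer> labelToId) {
        float[] vector = new float[labelToId.size()];
        for (String label : labels) {
            Integer id = labelToId.get(label);
            if (id == null) {
                // This sketch rejects labels the mapping does not know about.
                throw new IllegalArgumentException("Unknown label: " + label);
            }
            vector[id] = 1.0f;
        }
        return vector;
    }
}
```

For example, with ids {arts=0, politics=1, sports=2}, the label set {politics, sports} encodes to [0, 1, 1]. Rejecting unknown labels is a choice made by this sketch. This page does not state how MultiLabelConverter handles them.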

Predictions are thresholded at THRESHOLD: labels whose probability is above this value are considered present in the output.
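The thresholding step can be sketched in plain Java as well. The sketch assumes the per-label scores are already probabilities in [0, 1]. The threshold is passed in explicitly rather than assuming the value of THRESHOLD. A label scoring exactly at the threshold is excluded, because only probabilities above it count. ThresholdDecodingSketch and its decode method are hypothetical, and the String[] array stands in for the id-to-label lookup of ImmutableOutputInfo.

```java
import java.util.LinkedHashSet;
import java.util.Set;

/** Hypothetical helper: turns per-label probabilities into the set of predicted labels. Not Tribuo API. */
final class ThresholdDecodingSketch {
    private ThresholdDecodingSketch() {}

    /**
     * @param probabilities one probability per label id
     * @param idToLabel     stand-in for the id-to-label lookup of ImmutableOutputInfo
     * @param threshold     labels with probability strictly above this value are kept
     * @return the predicted labels, in id order
     */
    static Set<String> decode(float[] probabilities, String[] idToLabel, double threshold) {
        if (probabilities.length != idToLabel.length) {
            throw new IllegalArgumentException(
                "Expected " + idToLabel.length + " probabilities, got " + probabilities.length);
        }
        Set<String> present = new LinkedHashSet<>();
        for (int id = 0; id < probabilities.length; id++) {
            if (probabilities[id] > threshold) {
                present.add(idToLabel[id]);
            }
        }
        return present;
    }
}
```

With labels [arts, politics, sports], scores [0.9, 0.2, 0.5] and a threshold of 0.5, only arts is predicted.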

  • Field Details

    • THRESHOLD

      public static final double THRESHOLD
      The threshold to determine if a label has been predicted.
  • Constructor Details

    • MultiLabelConverter

      public MultiLabelConverter()
      Constructs a MultiLabelConverter.
  • Method Details

    • loss

      public BiFunction<org.tensorflow.op.Ops,com.oracle.labs.mlrg.olcut.util.Pair<org.tensorflow.op.core.Placeholder<? extends org.tensorflow.types.family.TNumber>,org.tensorflow.Operand<org.tensorflow.types.family.TNumber>>,org.tensorflow.Operand<org.tensorflow.types.family.TNumber>> loss()
      Returns a sigmoid cross-entropy loss.
      Specified by:
      loss in interface OutputConverter<MultiLabel>
      Returns:
      The sigmoid cross-entropy loss.
    • outputTransformFunction

      public <V extends org.tensorflow.types.family.TNumber> BiFunction<org.tensorflow.op.Ops,org.tensorflow.Operand<V>,org.tensorflow.op.Op> outputTransformFunction()
      Applies a sigmoid.
      Specified by:
      outputTransformFunction in interface OutputConverter<MultiLabel>
      Type Parameters:
      V - The sigmoid input type (should be TFloat32).
      Returns:
      A function which applies a sigmoid.
    • convertToPrediction

      public Prediction<MultiLabel> convertToPrediction(org.tensorflow.Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo, int numValidFeatures, Example<MultiLabel> example)
      Description copied from interface: OutputConverter
      Converts a Tensor into a Prediction.
      Specified by:
      convertToPrediction in interface OutputConverter<MultiLabel>
      Parameters:
      tensor - The tensor to convert.
      outputIDInfo - The output info to use to identify the outputs.
      numValidFeatures - The number of valid features used by the prediction.
      example - The example to insert into the prediction.
      Returns:
      A prediction object.
    • convertToOutput

      public MultiLabel convertToOutput(org.tensorflow.Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo)
      Description copied from interface: OutputConverter
      Converts a Tensor into the specified output type.
      Specified by:
      convertToOutput in interface OutputConverter<MultiLabel>
      Parameters:
      tensor - The tensor to convert.
      outputIDInfo - The output info to use to identify the outputs.
      Returns:
      An output.
    • convertToBatchPrediction

      public List<Prediction<MultiLabel>> convertToBatchPrediction(org.tensorflow.Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo, int[] numValidFeatures, List<Example<MultiLabel>> examples)
      Description copied from interface: OutputConverter
      Converts a Tensor containing multiple outputs into a list of Predictions.
      Specified by:
      convertToBatchPrediction in interface OutputConverter<MultiLabel>
      Parameters:
      tensor - The tensor to convert.
      outputIDInfo - The output info to use to identify the outputs.
      numValidFeatures - The number of valid features used by each prediction.
      examples - The examples to insert into the predictions.
      Returns:
      A list of predictions.
    • convertToBatchOutput

      public List<MultiLabel> convertToBatchOutput(org.tensorflow.Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo)
      Description copied from interface: OutputConverter
      Converts a Tensor containing multiple outputs into a list of Outputs.
      Specified by:
      convertToBatchOutput in interface OutputConverter<MultiLabel>
      Parameters:
      tensor - The tensor to convert.
      outputIDInfo - The output info to use to identify the outputs.
      Returns:
      A list of outputs.
    • convertToTensor

      public org.tensorflow.Tensor convertToTensor(MultiLabel example, ImmutableOutputInfo<MultiLabel> outputIDInfo)
      Description copied from interface: OutputConverter
      Converts an Output into a Tensor representing its output.
      Specified by:
      convertToTensor in interface OutputConverter<MultiLabel>
      Parameters:
      example - The output to convert.
      outputIDInfo - The output info to use to identify the outputs.
      Returns:
      A Tensor representing this output.
    • convertToTensor

      public org.tensorflow.Tensor convertToTensor(List<Example<MultiLabel>> examples, ImmutableOutputInfo<MultiLabel> outputIDInfo)
      Description copied from interface: OutputConverter
      Converts a list of Examples into a Tensor representing all the outputs in the list. It accepts a list of Examples rather than a list of Outputs for efficiency reasons.
      Specified by:
      convertToTensor in interface OutputConverter<MultiLabel>
      Parameters:
      examples - The examples to convert.
      outputIDInfo - The output info to use to identify the outputs.
      Returns:
      A Tensor representing all the supplied Outputs.
    • generatesProbabilities

      public boolean generatesProbabilities()
      Description copied from interface: OutputConverter
      Does this OutputConverter generate probabilities?
      Specified by:
      generatesProbabilities in interface OutputConverter<MultiLabel>
      Returns:
      True if it produces a probability distribution in the Prediction.
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
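loss() is documented as a sigmoid cross-entropy, which treats each label as an independent binary decision. The arithmetic can be sketched in plain Java under a few assumptions. Logits (pre-sigmoid scores) and 0/1 targets are laid out as [batch][numLabels], and the loss is averaged over every entry. The reduction the TensorFlow graph actually uses is not stated on this page. The per-entry term max(x, 0) - x*z + log(1 + e^(-|x|)) is the numerically stable rewrite of -[z log σ(x) + (1 - z) log(1 - σ(x))]. SigmoidCrossEntropySketch is a hypothetical name, not part of Tribuo.

```java
/** Hypothetical plain-Java illustration of mean sigmoid cross-entropy on logits. Not Tribuo or TensorFlow API. */
final class SigmoidCrossEntropySketch {
    private SigmoidCrossEntropySketch() {}

    /** Logistic sigmoid: maps a logit to a probability in (0, 1). */
    static double sigmoid(double logit) {
        return 1.0 / (1.0 + Math.exp(-logit));
    }

    /**
     * Mean over all entries of max(x, 0) - x * z + log(1 + exp(-|x|)), the overflow-safe form of
     * -[z * log(sigmoid(x)) + (1 - z) * log(1 - sigmoid(x))] for logit x and 0/1 target z.
     */
    static double meanLoss(float[][] logits, float[][] targets) {
        if (logits.length != targets.length) {
            throw new IllegalArgumentException("Batch sizes differ");
        }
        double total = 0.0;
        long count = 0;
        for (int i = 0; i < logits.length; i++) {
            if (logits[i].length != targets[i].length) {
                throw new IllegalArgumentException("Label counts differ in row " + i);
            }
            for (int j = 0; j < logits[i].length; j++) {
                double x = logits[i][j];
                double z = targets[i][j];
                total += Math.max(x, 0.0) - x * z + Math.log1p(Math.exp(-Math.abs(x)));
                count++;
            }
        }
        return count == 0 ? 0.0 : total / count;
    }
}
```

A logit of 0 gives σ(0) = 0.5 and a per-entry loss of ln 2 (about 0.693), whatever the target.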
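For convertToPrediction, a multi-label prediction can be pictured as three things together: the set of labels scored above the threshold, a score for every known label, and the numValidFeatures count passed through unchanged. The sketch below assumes one tensor row has already been read into a float[] of per-label probabilities. PredictionSketch and fromScores are hypothetical stand-ins for Prediction<MultiLabel>, not Tribuo API. Whether the real Prediction stores a score for every label is an assumption here.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/** Hypothetical stand-in for what a multi-label Prediction carries. Not Tribuo API. */
final class PredictionSketch {
    final Set<String> predictedLabels;     // labels scored strictly above the threshold
    final Map<String, Double> labelScores; // a score for every known label, predicted or not
    final int numValidFeatures;            // copied through unchanged

    private PredictionSketch(Set<String> predictedLabels, Map<String, Double> labelScores, int numValidFeatures) {
        this.predictedLabels = Collections.unmodifiableSet(predictedLabels);
        this.labelScores = Collections.unmodifiableMap(labelScores);
        this.numValidFeatures = numValidFeatures;
    }

    /**
     * @param scores    one probability per label id (a single row of the output tensor)
     * @param idToLabel stand-in for the id-to-label lookup of ImmutableOutputInfo
     */
    static PredictionSketch fromScores(float[] scores, String[] idToLabel, double threshold, int numValidFeatures) {
        if (scores.length != idToLabel.length) {
            throw new IllegalArgumentException("Expected " + idToLabel.length + " scores, got " + scores.length);
        }
        Set<String> predicted = new LinkedHashSet<>();
        Map<String, Double> all = new LinkedHashMap<>();
        for (int id = 0; id < scores.length; id++) {
            all.put(idToLabel[id], (double) scores[id]);
            if (scores[id] > threshold) {
                predicted.add(idToLabel[id]);
            }
        }
        return new PredictionSketch(predicted, all, numValidFeatures);
    }
}
```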
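convertToBatchPrediction pairs row i of a [batch][numLabels] output with examples.get(i) and numValidFeatures[i]. A hypothetical generic helper, BatchRowsSketch.mapRows, shows that pairing in plain Java. The caller supplies the per-row conversion. The length check is this sketch's own choice, not documented behaviour.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

/** Hypothetical helper: applies a per-row conversion to a batch of outputs. Not Tribuo API. */
final class BatchRowsSketch {
    private BatchRowsSketch() {}

    /**
     * @param batch            one row of per-label scores per example, shape [batch][numLabels]
     * @param numValidFeatures one entry per example, aligned with the rows of batch
     * @param perRow           converts a single row and its feature count into a result
     * @return one result per row, in row order
     */
    static <T> List<T> mapRows(float[][] batch, int[] numValidFeatures, BiFunction<float[], Integer, T> perRow) {
        if (numValidFeatures.length != batch.length) {
            throw new IllegalArgumentException(
                "Expected " + batch.length + " feature counts, got " + numValidFeatures.length);
        }
        List<T> results = new ArrayList<>(batch.length);
        for (int i = 0; i < batch.length; i++) {
            results.add(perRow.apply(batch[i], numValidFeatures[i]));
        }
        return results;
    }
}
```

Keeping the per-row step as a parameter lets the same loop serve both convertToBatchPrediction and convertToBatchOutput in this sketch. Only the function passed in changes.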
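The list form of convertToTensor can be pictured as stacking one multi-hot row per example into a [batch][numLabels] matrix. This sketch assumes label ids 0..n-1 and represents each Example only by its set of label names. BatchEncodingSketch and encodeBatch are hypothetical, not Tribuo API, and the real method returns a Tensor.

```java
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical helper: stacks multi-hot label rows into a batch matrix. Not Tribuo API. */
final class BatchEncodingSketch {
    private BatchEncodingSketch() {}

    /**
     * @param batchLabels the label names of each example, in batch order
     * @param labelToId   stand-in for the id lookup of ImmutableOutputInfo; ids assumed to be 0..n-1
     * @return a [batch][numLabels] matrix with 1.0f wherever a label is present
     */
    static float[][] encodeBatch(List<Set<String>> batchLabels, Map<String, Integer> labelToId) {
        float[][] matrix = new float[batchLabels.size()][labelToId.size()];
        for (int row = 0; row < batchLabels.size(); row++) {
            for (String label : batchLabels.get(row)) {
                Integer id = labelToId.get(label);
                if (id == null) {
                    // This sketch rejects labels the mapping does not know about.
                    throw new IllegalArgumentException("Unknown label: " + label);
                }
                matrix[row][id] = 1.0f;
            }
        }
        return matrix;
    }
}
```

Examples labelled {a}, {a, c} and {} with ids {a=0, b=1, c=2} give rows [1, 0, 0], [1, 0, 1] and [0, 0, 0].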