Class BERTFeatureExtractor<T extends Output<T>>

java.lang.Object
org.tribuo.interop.onnx.extractors.BERTFeatureExtractor<T>
Type Parameters:
T - The output type.
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, AutoCloseable, TextFeatureExtractor<T>, TextPipeline

public class BERTFeatureExtractor<T extends Output<T>> extends Object implements AutoCloseable, TextFeatureExtractor<T>, TextPipeline
Builds examples and sequence examples using features from BERT.

Assumes the BERT model is in ONNX format, generated by HuggingFace Transformers and exported using their export tool.

The tokenizer is expected to be a HuggingFace Transformers tokenizer config JSON file.

  • Constructor Details

    • BERTFeatureExtractor

      public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath)
      Constructs a BERTFeatureExtractor.
      Parameters:
      outputFactory - The output factory to use for building any unknown outputs.
      modelPath - The path to BERT in onnx format.
      tokenizerPath - The path to a HuggingFace tokenizer JSON file.
    • BERTFeatureExtractor

      public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath, BERTFeatureExtractor.OutputPooling pooling, int maxLength, boolean useCUDA)
      Constructs a BERTFeatureExtractor.
      Parameters:
      outputFactory - The output factory to use for building any unknown outputs.
      modelPath - The path to BERT in onnx format.
      tokenizerPath - The path to a HuggingFace tokenizer JSON file.
      pooling - The pooling type for extracted Examples.
      maxLength - The maximum number of wordpieces.
      useCUDA - Set to true to enable CUDA.
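As a usage sketch of the six-argument constructor (the file paths, the Label output type, and the pooling/length choices are illustrative assumptions, not fixed by this API; requires Tribuo's ONNX interop and classification modules plus ONNX Runtime on the classpath):

```java
import java.nio.file.Path;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.interop.onnx.extractors.BERTFeatureExtractor;

public class BertExtractorDemo {
    public static void main(String[] args) throws Exception {
        // Hypothetical locations: a BERT exported to ONNX by HuggingFace
        // Transformers, and its tokenizer config JSON file.
        Path modelPath = Path.of("bert-base-uncased.onnx");
        Path tokenizerPath = Path.of("tokenizer.json");

        // CLS pooling, a 256 wordpiece budget, CPU execution.
        try (BERTFeatureExtractor<Label> extractor = new BERTFeatureExtractor<>(
                new LabelFactory(), modelPath, tokenizerPath,
                BERTFeatureExtractor.OutputPooling.CLS, 256, false)) {
            System.out.println(extractor.getMaxLength());
        }
    }
}
```

The extractor is AutoCloseable (it owns a native OrtSession), so try-with-resources is the natural way to manage it.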
  • Method Details

    • postConfig

      public void postConfig() throws com.oracle.labs.mlrg.olcut.config.PropertyException
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
      Throws:
      com.oracle.labs.mlrg.olcut.config.PropertyException
    • getProvenance

      public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
    • reconfigureOrtSession

      public void reconfigureOrtSession(ai.onnxruntime.OrtSession.SessionOptions options) throws ai.onnxruntime.OrtException
      Reconstructs the OrtSession using the supplied options. This allows the use of different computation backends and configurations.
      Parameters:
      options - The new session options.
      Throws:
      ai.onnxruntime.OrtException - If the native runtime failed to rebuild itself.
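For example, to move computation onto a GPU after construction, the session could be rebuilt with CUDA enabled (a sketch only; assumes an already-constructed extractor and an ONNX Runtime build that includes the CUDA execution provider):

```java
import ai.onnxruntime.OrtSession;

// extractor is an already-constructed BERTFeatureExtractor<T>
try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
    options.addCUDA(0); // device id 0 is an assumption for this sketch
    extractor.reconfigureOrtSession(options);
}
```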
    • getMaxLength

      public int getMaxLength()
      Returns the maximum length this BERT will accept.
      Returns:
      The maximum number of tokens (including [CLS] and [SEP], so the maximum is effectively 2 less than this).
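To make the parenthetical concrete (512 is a hypothetical configured value, chosen because it is a common BERT sequence limit):

```java
public class MaxLengthBudget {
    public static void main(String[] args) {
        int maxLength = 512; // hypothetical value returned by getMaxLength()
        // [CLS] and [SEP] count against the limit, so callers can supply
        // at most maxLength - 2 wordpieces of their own.
        int callerBudget = maxLength - 2;
        System.out.println(callerBudget); // prints 510
    }
}
```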
    • getVocab

      public Set<String> getVocab()
      Returns the vocabulary that this BERTFeatureExtractor understands.
      Returns:
      The vocabulary.
    • extractExample

      public Example<T> extractExample(List<String> tokens)
      Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.

      The features of the returned example are dense, and come from the [CLS] token.

      Parameters:
      tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
      Returns:
      A dense example representing the pooled output from BERT for the input tokens.
      Throws:
      IllegalArgumentException - If the list is longer than getMaxLength().
      IllegalStateException - If the BERT model failed to produce an output.
    • extractExample

      public Example<T> extractExample(List<String> tokens, T output)
      Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.

      The features of the returned example are dense, and are controlled by the output pooling field.

      Parameters:
      tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
      output - The ground truth output for this example.
      Returns:
      A dense example representing the pooled output from BERT for the input tokens.
      Throws:
      IllegalArgumentException - If the list is longer than getMaxLength().
      IllegalStateException - If the BERT model failed to produce an output.
    • extractSequenceExample

      public SequenceExample<T> extractSequenceExample(List<String> tokens, boolean stripSentenceMarkers)
      Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.

      The features of each example are dense. If stripSentenceMarkers is true then the [CLS] and [SEP] tokens are removed before example generation. If it's false then they are left in with the appropriate unknown output set.

      Parameters:
      tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
      stripSentenceMarkers - Remove the [CLS] and [SEP] tokens from the returned example.
      Returns:
      A dense sequence example representing the token level output from BERT.
      Throws:
      IllegalArgumentException - If the list is longer than getMaxLength().
      IllegalStateException - If the BERT model failed to produce an output.
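The stripSentenceMarkers behaviour can be pictured with plain list operations (a simplified sketch of the token handling only; no BERT inference is involved):

```java
import java.util.ArrayList;
import java.util.List;

public class StripMarkersDemo {
    public static void main(String[] args) {
        // Token-level output as BERT sees it, with sentence markers attached.
        List<String> withMarkers = List.of("[CLS]", "the", "cat", "sat", "[SEP]");

        // With stripSentenceMarkers == true the extractor drops [CLS] and
        // [SEP] before building the SequenceExample.
        List<String> stripped = new ArrayList<>(withMarkers);
        stripped.removeIf(t -> t.equals("[CLS]") || t.equals("[SEP]"));

        System.out.println(stripped); // prints [the, cat, sat]
    }
}
```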
    • extractSequenceExample

      public SequenceExample<T> extractSequenceExample(List<String> tokens, List<T> output, boolean stripSentenceMarkers)
      Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.

      The features of each example are dense. The output list must be the same length as the number of tokens. If stripSentenceMarkers is true then the [CLS] and [SEP] tokens are removed before example generation. If it's false then they are left in with the appropriate unknown output set.

      Parameters:
      tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
      output - The ground truth output for this example.
      stripSentenceMarkers - Remove the [CLS] and [SEP] tokens from the returned example.
      Returns:
      A dense sequence example representing the token level output from BERT.
      Throws:
      IllegalArgumentException - If the list is longer than getMaxLength().
      IllegalStateException - If the BERT model failed to produce an output.
    • close

      public void close() throws ai.onnxruntime.OrtException
      Specified by:
      close in interface AutoCloseable
      Throws:
      ai.onnxruntime.OrtException
    • extract

      public Example<T> extract(T output, String data)
      Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
      Specified by:
      extract in interface TextFeatureExtractor<T extends Output<T>>
      Parameters:
      output - The output object.
      data - The input text.
      Returns:
      An example containing BERT embedding features and the requested output.
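The truncation step described above can be sketched in isolation (a maxLength of 6 is an artificially small value for illustration):

```java
import java.util.List;

public class TruncationDemo {
    public static void main(String[] args) {
        int maxLength = 6; // artificially small for the example
        List<String> tokens = List.of("a", "b", "c", "d", "e", "f", "g");

        // Truncate to maxLength - 2, leaving room for [CLS] and [SEP].
        List<String> truncated = tokens.size() > maxLength - 2
                ? tokens.subList(0, maxLength - 2)
                : tokens;

        System.out.println(truncated); // prints [a, b, c, d]
    }
}
```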
    • process

      public List<Feature> process(String tag, String data)
      Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
      Specified by:
      process in interface TextPipeline
      Parameters:
      tag - A tag to prefix all the generated feature names with.
      data - The input text.
      Returns:
      The BERT features for the supplied data.
    • main

      public static void main(String[] args) throws IOException, ai.onnxruntime.OrtException
      Test harness for running a BERT model and inspecting the output.
      Parameters:
      args - The CLI arguments.
      Throws:
      IOException - If the files couldn't be read or written to.
      ai.onnxruntime.OrtException - If the BERT model failed to load, or threw an exception during computation.