Class BERTFeatureExtractor<T extends Output<T>>
- Type Parameters:
T - The output type.
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, AutoCloseable, TextFeatureExtractor<T>, TextPipeline
Assumes the BERT model is an ONNX model generated by HuggingFace Transformers and exported using their export tool.
The tokenizer is expected to be a HuggingFace Transformers tokenizer config JSON file.
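Typical usage constructs the extractor, embeds text, and closes it to release the ONNX Runtime session. The following minimal sketch is illustrative rather than taken from this page: the file paths, the Label output type, and the org.tribuo.* import locations are assumptions.

import java.nio.file.Path;
import java.nio.file.Paths;

import org.tribuo.Example;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.interop.onnx.extractors.BERTFeatureExtractor;

public class BertExtractionSketch {
    public static void main(String[] args) throws Exception {
        Path modelPath = Paths.get("bert-base-uncased.onnx"); // hypothetical HuggingFace ONNX export
        Path tokenizerPath = Paths.get("tokenizer.json");     // hypothetical tokenizer config file

        // The extractor owns an ONNX Runtime session, so close it when finished (AutoCloseable).
        try (BERTFeatureExtractor<Label> bert =
                 new BERTFeatureExtractor<>(new LabelFactory(), modelPath, tokenizerPath)) {
            // extract tokenizes the raw text internally and returns a dense Example.
            Example<Label> example = bert.extract(new Label("POSITIVE"), "A well constructed sentence.");
            System.out.println("Number of features: " + example.size());
        }
    }
}

The snippets further down this page assume the bert extractor constructed here.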
-
Nested Class Summary
static class BERTFeatureExtractor.BERTFeatureExtractorOptions - CLI options for running BERT.
static enum BERTFeatureExtractor.OutputPooling - The type of output pooling to perform.
-
Field Summary
static final String ATTENTION_MASK - Input name for the attention mask.
static final String CLASSIFICATION_TOKEN - Default classification token name.
static final String CLS_OUTPUT - Output name for the classification token output.
static final String INPUT_IDS - Input name for the token ids.
static final long MASK_VALUE - Mask value.
static final String SEPARATOR_TOKEN - Default separator token name.
static final String TOKEN_METADATA - Metadata key for the token value stored in a Tribuo Example.
static final String TOKEN_OUTPUT - Output name for the token level outputs.
static final String TOKEN_TYPE_IDS - Input name for the token type ids.
static final long TOKEN_TYPE_VALUE - Token type value for the first sentence.
static final String UNKNOWN_TOKEN - Default unknown token name.
-
Constructor Summary
BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath) - Constructs a BERTFeatureExtractor.
BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath, BERTFeatureExtractor.OutputPooling pooling, int maxLength, boolean useCUDA) - Constructs a BERTFeatureExtractor.
-
Method Summary
void close()
Example<T> extract(T output, String data) - Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
Example<T> extractExample(List<String> tokens) - Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
Example<T> extractExample(List<String> tokens, T output) - Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
SequenceExample<T> extractSequenceExample(List<String> tokens, boolean stripSentenceMarkers) - Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
SequenceExample<T> extractSequenceExample(List<String> tokens, List<T> output, boolean stripSentenceMarkers) - Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
int getMaxLength() - Returns the maximum length this BERT will accept.
com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
Set<String> getVocab() - Returns the vocabulary that this BERTFeatureExtractor understands.
static void main(String[] args) - Test harness for running a BERT model and inspecting the output.
void postConfig()
List<Feature> process(String tag, String data) - Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
void reconfigureOrtSession(ai.onnxruntime.OrtSession.SessionOptions options) - Reconstructs the OrtSession using the supplied options.
-
Field Details
-
INPUT_IDS
public static final String INPUT_IDS
Input name for the token ids.
-
ATTENTION_MASK
public static final String ATTENTION_MASK
Input name for the attention mask.
-
TOKEN_TYPE_IDS
public static final String TOKEN_TYPE_IDS
Input name for the token type ids.
-
TOKEN_OUTPUT
public static final String TOKEN_OUTPUT
Output name for the token level outputs.
-
CLS_OUTPUT
public static final String CLS_OUTPUT
Output name for the classification token output.
-
CLASSIFICATION_TOKEN
public static final String CLASSIFICATION_TOKEN
Default classification token name.
-
SEPARATOR_TOKEN
public static final String SEPARATOR_TOKEN
Default separator token name.
-
UNKNOWN_TOKEN
public static final String UNKNOWN_TOKEN
Default unknown token name.
-
TOKEN_METADATA
public static final String TOKEN_METADATA
Metadata key for the token value stored in a Tribuo Example.
-
MASK_VALUE
public static final long MASK_VALUE
Mask value.
-
TOKEN_TYPE_VALUE
public static final long TOKEN_TYPE_VALUE
Token type value for the first sentence.
-
-
Constructor Details
-
BERTFeatureExtractor
public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath)
Constructs a BERTFeatureExtractor.
- Parameters:
outputFactory - The output factory to use for building any unknown outputs.
modelPath - The path to BERT in ONNX format.
tokenizerPath - The path to a HuggingFace tokenizer JSON file.
-
BERTFeatureExtractor
public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath, BERTFeatureExtractor.OutputPooling pooling, int maxLength, boolean useCUDA)
Constructs a BERTFeatureExtractor.
- Parameters:
outputFactory - The output factory to use for building any unknown outputs.
modelPath - The path to BERT in ONNX format.
tokenizerPath - The path to a HuggingFace tokenizer JSON file.
pooling - The pooling type for extracted Examples.
maxLength - The maximum number of wordpieces.
useCUDA - Set to true to enable CUDA.
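For illustration, a sketch of calling this constructor; the paths, the MEAN pooling constant, and the maxLength of 256 are assumptions rather than values taken from this page.

// Illustrative values throughout; OutputPooling.MEAN is an assumed enum constant.
BERTFeatureExtractor<Label> bert = new BERTFeatureExtractor<>(
        new LabelFactory(),                      // output factory for building unknown outputs
        Paths.get("bert-base-uncased.onnx"),     // hypothetical ONNX export
        Paths.get("tokenizer.json"),             // hypothetical tokenizer config
        BERTFeatureExtractor.OutputPooling.MEAN, // assumed pooling constant
        256,                                     // maximum number of wordpieces
        false);                                  // true would enable CUDA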
-
-
Method Details
-
postConfig
public void postConfig() throws com.oracle.labs.mlrg.olcut.config.PropertyException
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Throws:
com.oracle.labs.mlrg.olcut.config.PropertyException
-
getProvenance
public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
-
reconfigureOrtSession
public void reconfigureOrtSession(ai.onnxruntime.OrtSession.SessionOptions options) throws ai.onnxruntime.OrtException
Reconstructs the OrtSession using the supplied options. This allows the use of different computation backends and configurations.
- Parameters:
options - The new session options.
- Throws:
ai.onnxruntime.OrtException - If the native runtime failed to rebuild the session.
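For example, the session can be rebuilt with different threading settings. This sketch assumes the bert extractor from the example near the top of the page, an import of ai.onnxruntime.OrtSession, and a surrounding method that declares throws ai.onnxruntime.OrtException; setIntraOpNumThreads is a standard ONNX Runtime SessionOptions call and the thread count is arbitrary.

// Rebuild the internal OrtSession with new options (illustrative thread count).
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setIntraOpNumThreads(4);
bert.reconfigureOrtSession(options);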
-
getMaxLength
public int getMaxLength()
Returns the maximum length this BERT will accept.
- Returns:
- The maximum number of tokens (including [CLS] and [SEP], so the maximum is effectively 2 less than this).
-
getVocab
public Set<String> getVocab()
Returns the vocabulary that this BERTFeatureExtractor understands.
- Returns:
- The vocabulary.
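For illustration, the vocabulary can be used to check coverage of a token list before extraction (unknown tokens are mapped to [UNK] in any case). This sketch assumes the bert extractor from the earlier example, the java.util imports, and that the vocabulary is returned as a Set<String>.

// Count how many tokens are in the BERT vocabulary (illustrative check only).
Set<String> vocab = bert.getVocab();
List<String> tokens = List.of("the", "quick", "brown", "fox");
long known = tokens.stream().filter(vocab::contains).count();
System.out.println(known + "/" + tokens.size() + " tokens are in vocabulary");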
-
extractExample
public Example<T> extractExample(List<String> tokens)
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token. The features of the returned example are dense, and come from the [CLS] token.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
- Returns:
- A dense example representing the pooled output from BERT for the input tokens.
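A sketch of this pre-tokenized entry point, guarding against the IllegalArgumentException for over-long inputs; the wordpiece tokens are hand-written for illustration and the bert extractor is assumed from the earlier example.

// Hand-written wordpiece tokens; in practice produce these with the tokenizer this BERT expects.
List<String> tokens = new ArrayList<>(List.of("a", "well", "constructed", "sentence", "."));

// Keep at most maxLength - 2 tokens, leaving room for [CLS] and [SEP].
int limit = bert.getMaxLength() - 2;
if (tokens.size() > limit) {
    tokens = new ArrayList<>(tokens.subList(0, limit));
}

Example<Label> embedded = bert.extractExample(tokens); // dense features from the [CLS] token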
-
extractExample
public Example<T> extractExample(List<String> tokens, T output)
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token. The features of the returned example are dense, and are controlled by the output pooling field.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
output - The ground truth output for this example.
- Returns:
- A dense example representing the pooled output from BERT for the input tokens.
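When the ground truth is known, this overload attaches it directly to the embedded example; the tokens and label are illustrative and the bert extractor is assumed from the earlier example.

// Build a labelled example from pre-tokenized text (illustrative values).
List<String> tokens = List.of("what", "a", "great", "film", "!");
Example<Label> trainingExample = bert.extractExample(tokens, new Label("POSITIVE"));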
-
extractSequenceExample
public SequenceExample<T> extractSequenceExample(List<String> tokens, boolean stripSentenceMarkers)
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token. The features of each example are dense. If stripSentenceMarkers is true then the [CLS] and [SEP] tokens are removed before example generation. If it's false then they are left in with the appropriate unknown output set.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
stripSentenceMarkers - Remove the [CLS] and [SEP] tokens from the returned example.
- Returns:
- A dense sequence example representing the token level output from BERT.
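A sketch of extracting per-token embeddings without labels, for instance before scoring with a trained sequence model; the tokens are illustrative and the bert extractor is assumed from the earlier example.

// Token-level embeddings for an unlabelled sentence.
// stripSentenceMarkers = true drops the [CLS] and [SEP] rows so the output lines up with the input tokens.
List<String> tokens = List.of("jane", "visited", "london");
SequenceExample<Label> sequence = bert.extractSequenceExample(tokens, true);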
-
extractSequenceExample
public SequenceExample<T> extractSequenceExample(List<String> tokens, List<T> output, boolean stripSentenceMarkers)
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token. The features of each example are dense. The output list must be the same length as the number of tokens. If stripSentenceMarkers is true then the [CLS] and [SEP] tokens are removed before example generation. If it's false then they are left in with the appropriate unknown output set.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
output - The ground truth output for this example.
stripSentenceMarkers - Remove the [CLS] and [SEP] tokens from the returned example.
- Returns:
- A dense sequence example representing the token level output from BERT.
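A sketch of building a token-labelled sequence example, e.g. for NER-style training data, where the per-token output list must match the token list in length. The tokens and label names are illustrative and the bert extractor is assumed from the earlier example.

// Per-token ground truth; the output list length must equal the token list length.
List<String> tokens = List.of("jane", "visited", "london");
List<Label> labels = List.of(new Label("PERSON"), new Label("O"), new Label("LOCATION"));
SequenceExample<Label> labelled = bert.extractSequenceExample(tokens, labels, true);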
-
close
public void close() throws ai.onnxruntime.OrtException
- Specified by:
close in interface AutoCloseable
- Throws:
ai.onnxruntime.OrtException
-
extract
public Example<T> extract(T output, String data)
Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
- Specified by:
extract in interface TextFeatureExtractor<T extends Output<T>>
- Parameters:
output - The output object.
data - The input text.
- Returns:
- An example containing BERT embedding features and the requested output.
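Since extract handles tokenization and truncation internally, it is the most direct way to turn raw labelled strings into examples; the texts and labels below are illustrative and the bert extractor is assumed from the earlier example.

// Embed raw labelled strings directly (illustrative data).
List<Example<Label>> examples = new ArrayList<>();
examples.add(bert.extract(new Label("POSITIVE"), "I loved this film."));
examples.add(bert.extract(new Label("NEGATIVE"), "A complete waste of two hours."));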
-
process
public List<Feature> process(String tag, String data)
Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
- Specified by:
process in interface TextPipeline
- Parameters:
tag - A tag to prefix all the generated feature names with.
data - The input text.
- Returns:
- The BERT features for the supplied data.
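As a TextPipeline, process returns the features themselves (with the tag prefixed onto each feature name) rather than a full Example, which is useful when combining pipelines over several text fields. The tag and text are illustrative and the bert extractor is assumed from the earlier example.

// Tagged BERT features for one text field (illustrative tag and text).
List<Feature> features = bert.process("review", "I loved this film.");
for (Feature f : features) {
    System.out.println(f.getName() + " = " + f.getValue());
}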
-
main
public static void main(String[] args) throws IOException, ai.onnxruntime.OrtException
Test harness for running a BERT model and inspecting the output.
- Parameters:
args - The CLI arguments.
- Throws:
IOException - If the files couldn't be read or written to.
ai.onnxruntime.OrtException - If the BERT model failed to load, or threw an exception during computation.
-