Class BERTFeatureExtractor<T extends Output<T>>
- Type Parameters:
T - The output type.
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, AutoCloseable, TextFeatureExtractor<T>, TextPipeline
Assumes that the BERT model is in ONNX format, generated by HuggingFace Transformers and exported using the Transformers export tool.
The tokenizer is expected to be a HuggingFace Transformers tokenizer config JSON file.
-
Nested Class Summary
Nested Classes (Modifier and Type, Class, Description)
static class - CLI options for running BERT.
static enum BERTFeatureExtractor.OutputPooling - The type of output pooling to perform.
-
Field Summary
Fields -
Constructor Summary
Constructors (Constructor, Description)
BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath) - Constructs a BERTFeatureExtractor.
BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath, BERTFeatureExtractor.OutputPooling pooling, int maxLength, boolean useCUDA) - Constructs a BERTFeatureExtractor.
-
Method Summary
Methods (Modifier and Type, Method, Description)
void close()
extract - Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
extractExample(List<String> tokens) - Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
extractExample(List<String> tokens, T output) - Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
extractSequenceExample(List<String> tokens, boolean stripSentenceMarkers) - Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
extractSequenceExample(List<String> tokens, List<T> output, boolean stripSentenceMarkers) - Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
int getMaxLength() - Returns the maximum length this BERT will accept.
com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
getVocab() - Returns the vocabulary that this BERTFeatureExtractor understands.
static void main - Test harness for running a BERT model and inspecting the output.
void postConfig()
process - Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
void reconfigureOrtSession(ai.onnxruntime.OrtSession.SessionOptions options) - Reconstructs the OrtSession using the supplied options.
-
Field Details
-
INPUT_IDS
-
ATTENTION_MASK
-
TOKEN_TYPE_IDS
-
TOKEN_OUTPUT
-
CLS_OUTPUT
-
CLASSIFICATION_TOKEN
-
SEPARATOR_TOKEN
-
UNKNOWN_TOKEN
-
TOKEN_METADATA
-
MASK_VALUE
public static final long MASK_VALUE
-
TOKEN_TYPE_VALUE
public static final long TOKEN_TYPE_VALUE
-
-
Constructor Details
-
BERTFeatureExtractor
Constructs a BERTFeatureExtractor.
- Parameters:
outputFactory - The output factory to use for building any unknown outputs.
modelPath - The path to BERT in ONNX format.
tokenizerPath - The path to a HuggingFace tokenizer JSON file.
-
BERTFeatureExtractor
public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath, BERTFeatureExtractor.OutputPooling pooling, int maxLength, boolean useCUDA)
Constructs a BERTFeatureExtractor.
- Parameters:
outputFactory - The output factory to use for building any unknown outputs.
modelPath - The path to BERT in ONNX format.
tokenizerPath - The path to a HuggingFace tokenizer JSON file.
pooling - The pooling type for extracted Examples.
maxLength - The maximum number of wordpieces.
useCUDA - Set to true to enable CUDA.
-
-
Method Details
-
postConfig
public void postConfig() throws com.oracle.labs.mlrg.olcut.config.PropertyException
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Throws:
com.oracle.labs.mlrg.olcut.config.PropertyException
-
getProvenance
-
reconfigureOrtSession
public void reconfigureOrtSession(ai.onnxruntime.OrtSession.SessionOptions options) throws ai.onnxruntime.OrtException
Reconstructs the OrtSession using the supplied options. This allows the use of different computation backends and configurations.
- Parameters:
options - The new session options.
- Throws:
ai.onnxruntime.OrtException - If the native runtime failed to rebuild itself.
-
getMaxLength
public int getMaxLength()
Returns the maximum length this BERT will accept.
- Returns:
The maximum number of tokens (including [CLS] and [SEP], so the maximum is effectively 2 less than this).
-
getVocab
-
extractExample
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
The features of the returned example are dense, and come from the [CLS] token.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
- Returns:
A dense example representing the pooled output from BERT for the input tokens.
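The input checks described above can be illustrated with a minimal plain-Java sketch. It is not Tribuo's implementation: the class `UnknownTokenSketch`, the method `prepare`, and the explicit `vocab` and `maxLength` arguments are hypothetical names introduced for this example, and the length check follows the wording of this page (longer than getMaxLength()).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/** Illustration only; hypothetical helper, not part of BERTFeatureExtractor. */
final class UnknownTokenSketch {
    static final String UNKNOWN = "[UNK]";

    /**
     * Rejects over-long inputs, then swaps any token missing from the
     * vocabulary for the [UNK] token.
     */
    static List<String> prepare(List<String> tokens, Set<String> vocab, int maxLength) {
        if (tokens.size() > maxLength) {
            throw new IllegalArgumentException(
                "Too many tokens: " + tokens.size() + " > " + maxLength);
        }
        List<String> result = new ArrayList<>(tokens.size());
        for (String token : tokens) {
            result.add(vocab.contains(token) ? token : UNKNOWN);
        }
        return result;
    }
}
```

In real use the vocabulary would come from getVocab() and the limit from getMaxLength().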
-
extractExample
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
The features of the returned example are dense, and are controlled by the output pooling field.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
output - The ground truth output for this example.
- Returns:
A dense example representing the pooled output from BERT for the input tokens.
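To make the idea of output pooling concrete, here is a hedged sketch of two common ways to reduce per-token embeddings to one dense vector. The enum `Mode` and its constants `FIRST_TOKEN` and `MEAN` are hypothetical and are not the constants of BERTFeatureExtractor.OutputPooling; whether the real mean includes the [CLS] and [SEP] positions is not stated on this page, so this sketch simply averages every row it is given.

```java
/** Illustration only; hypothetical pooling helper, not Tribuo code. */
final class PoolingSketch {
    /** Hypothetical pooling modes used for this example. */
    enum Mode { FIRST_TOKEN, MEAN }

    /**
     * Reduces a [tokens][dimensions] matrix to a single vector.
     * Assumes every row has the same length.
     */
    static double[] pool(double[][] tokenEmbeddings, Mode mode) {
        if (tokenEmbeddings.length == 0) {
            throw new IllegalArgumentException("No token embeddings supplied");
        }
        int dim = tokenEmbeddings[0].length;
        double[] pooled = new double[dim];
        switch (mode) {
            case FIRST_TOKEN:
                // Take the vector at position 0, where [CLS] normally sits.
                System.arraycopy(tokenEmbeddings[0], 0, pooled, 0, dim);
                break;
            case MEAN:
                // Average each dimension across all token positions.
                for (double[] row : tokenEmbeddings) {
                    for (int i = 0; i < dim; i++) {
                        pooled[i] += row[i];
                    }
                }
                for (int i = 0; i < dim; i++) {
                    pooled[i] /= tokenEmbeddings.length;
                }
                break;
        }
        return pooled;
    }
}
```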
-
extractSequenceExample
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
The features of each example are dense. If stripSentenceMarkers is true then the [CLS] and [SEP] tokens are removed before example generation. If it's false then they are left in with the appropriate unknown output set.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
stripSentenceMarkers - Remove the [CLS] and [SEP] tokens from the returned example.
- Returns:
A dense sequence example representing the token level output from BERT.
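The effect of stripSentenceMarkers can be pictured with a small plain-Java sketch. It assumes the marker tokens sit at the first and last positions of the sequence; `MarkerStripSketch` and `positionsForExamples` are hypothetical names, and the sketch works on token strings where the real method works on per-token examples.

```java
import java.util.List;

/** Illustration only; hypothetical helper, not part of BERTFeatureExtractor. */
final class MarkerStripSketch {
    /**
     * Returns the positions that would become examples: everything when
     * stripSentenceMarkers is false, or everything except the leading
     * [CLS] and trailing [SEP] when it is true.
     */
    static List<String> positionsForExamples(List<String> wrapped, boolean stripSentenceMarkers) {
        if (wrapped.size() < 2) {
            throw new IllegalArgumentException("Expected at least [CLS] and [SEP]");
        }
        if (!stripSentenceMarkers) {
            return List.copyOf(wrapped);
        }
        // Drop index 0 ([CLS]) and the final index ([SEP]).
        return List.copyOf(wrapped.subList(1, wrapped.size() - 1));
    }
}
```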
-
extractSequenceExample
public SequenceExample<T> extractSequenceExample(List<String> tokens, List<T> output, boolean stripSentenceMarkers)
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
The features of each example are dense. The output list must be the same length as the number of tokens. If stripSentenceMarkers is true then the [CLS] and [SEP] tokens are removed before example generation. If it's false then they are left in with the appropriate unknown output set.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
output - The ground truth output for this example.
stripSentenceMarkers - Remove the [CLS] and [SEP] tokens from the returned example.
- Returns:
A dense sequence example representing the token level output from BERT.
-
close
public void close() throws ai.onnxruntime.OrtException
- Specified by:
close in interface AutoCloseable
- Throws:
ai.onnxruntime.OrtException
-
extract
Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
- Specified by:
extract in interface TextFeatureExtractor<T extends Output<T>>
- Parameters:
output - The output object.
data - The input text.
- Returns:
An example containing BERT embedding features and the requested output.
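The truncation rule described for extract can be sketched in plain Java as follows. This is an illustration under assumptions, not the library's code: `TruncationSketch` and `truncateAndWrap` are hypothetical names, and wrapping the result in [CLS] and [SEP] at this point is an assumption made to show why two slots are reserved.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustration only; hypothetical helper, not part of BERTFeatureExtractor. */
final class TruncationSketch {
    static final String CLS = "[CLS]";
    static final String SEP = "[SEP]";

    /** Keeps at most maxLength - 2 wordpieces, then wraps them in [CLS] ... [SEP]. */
    static List<String> truncateAndWrap(List<String> wordpieces, int maxLength) {
        if (maxLength < 2) {
            throw new IllegalArgumentException("maxLength must leave room for [CLS] and [SEP]");
        }
        // Two slots are reserved for the sentence markers.
        int limit = Math.min(wordpieces.size(), maxLength - 2);
        List<String> result = new ArrayList<>(limit + 2);
        result.add(CLS);
        result.addAll(wordpieces.subList(0, limit));
        result.add(SEP);
        return result;
    }
}
```

With a maxLength of 5, for instance, only the first three wordpieces of a longer input survive, so the wrapped sequence never exceeds five tokens.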
-
process
Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
- Specified by:
process in interface TextPipeline
- Parameters:
tag - A tag to prefix all the generated feature names with.
data - The input text.
- Returns:
The BERT features for the supplied data.
-
main
Test harness for running a BERT model and inspecting the output.
- Parameters:
args - The CLI arguments.
- Throws:
IOException - If the files couldn't be read or written to.
ai.onnxruntime.OrtException - If the BERT model failed to load, or threw an exception during computation.
-