Class BERTFeatureExtractor<T extends Output<T>>
- Type Parameters:
T - The output type.
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>, AutoCloseable, TextFeatureExtractor<T>, TextPipeline
Assumes that the BERT model is an ONNX model generated by HuggingFace Transformers and exported using their export tool.
The tokenizer is expected to be a HuggingFace Transformers tokenizer config JSON file.
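For orientation, here is a minimal sketch of loading the extractor and inspecting it. It assumes a Label output type, hypothetical local paths for the exported ONNX model and tokenizer JSON, and that getVocab() returns a collection whose size can be queried; adjust the names to your setup.

    import java.nio.file.Path;

    import org.tribuo.classification.Label;
    import org.tribuo.classification.LabelFactory;
    import org.tribuo.interop.onnx.extractors.BERTFeatureExtractor;

    public class BertLoadSketch {
        public static void main(String[] args) throws Exception {
            // Hypothetical paths: an ONNX export of a BERT model and its HuggingFace tokenizer.json.
            Path modelPath = Path.of("bert-base-uncased.onnx");
            Path tokenizerPath = Path.of("tokenizer.json");

            // try-with-resources closes the underlying ONNX Runtime session (the class is AutoCloseable).
            try (BERTFeatureExtractor<Label> bert =
                     new BERTFeatureExtractor<>(new LabelFactory(), modelPath, tokenizerPath)) {
                System.out.println("Max wordpieces  : " + bert.getMaxLength());
                System.out.println("Vocabulary size : " + bert.getVocab().size());
            }
        }
    }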
-
Nested Class Summary
Nested Classes
- static class: CLI options for running BERT.
- static enum BERTFeatureExtractor.OutputPooling: The type of output pooling to perform.
-
Field Summary
Fields
- static final String ATTENTION_MASK: Input name for the attention mask.
- static final String CLASSIFICATION_TOKEN: Default classification token name.
- static final String CLS_OUTPUT: Output name for the classification token output.
- static final String INPUT_IDS: Input name for the token ids.
- static final long MASK_VALUE: Mask value.
- static final String SEPARATOR_TOKEN: Default separator token name.
- static final String TOKEN_METADATA: Metadata key for the token value stored in a Tribuo Example.
- static final String TOKEN_OUTPUT: Output name for the token level outputs.
- static final String TOKEN_TYPE_IDS: Input name for the token type ids.
- static final long TOKEN_TYPE_VALUE: Token type value for the first sentence.
- static final String UNKNOWN_TOKEN: Default unknown token name.
-
Constructor Summary
Constructors
BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath)
    Constructs a BERTFeatureExtractor.
BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath, BERTFeatureExtractor.OutputPooling pooling, int maxLength, boolean useCUDA)
    Constructs a BERTFeatureExtractor.
-
Method Summary
Methods
void close()
Example<T> extract(T output, String data)
    Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for the [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
Example<T> extractExample(List<String> tokens)
    Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
Example<T> extractExample(List<String> tokens, T output)
    Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
SequenceExample<T> extractSequenceExample(List<String> tokens, boolean stripSentenceMarkers)
    Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
SequenceExample<T> extractSequenceExample(List<String> tokens, List<T> output, boolean stripSentenceMarkers)
    Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
int getMaxLength()
    Returns the maximum length this BERT will accept.
com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
Set<String> getVocab()
    Returns the vocabulary that this BERTFeatureExtractor understands.
static void main(String[] args)
    Test harness for running a BERT model and inspecting the output.
void postConfig()
List<Feature> process(String tag, String data)
    Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for the [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
void reconfigureOrtSession(ai.onnxruntime.OrtSession.SessionOptions options)
    Reconstructs the OrtSession using the supplied options.
-
Field Details
-
INPUT_IDS
public static final String INPUT_IDS
Input name for the token ids.
-
ATTENTION_MASK
public static final String ATTENTION_MASK
Input name for the attention mask.
-
TOKEN_TYPE_IDS
public static final String TOKEN_TYPE_IDS
Input name for the token type ids.
-
TOKEN_OUTPUT
public static final String TOKEN_OUTPUT
Output name for the token level outputs.
-
CLS_OUTPUT
public static final String CLS_OUTPUT
Output name for the classification token output.
-
CLASSIFICATION_TOKEN
public static final String CLASSIFICATION_TOKEN
Default classification token name.
-
SEPARATOR_TOKEN
public static final String SEPARATOR_TOKEN
Default separator token name.
-
UNKNOWN_TOKEN
public static final String UNKNOWN_TOKEN
Default unknown token name.
-
TOKEN_METADATA
public static final String TOKEN_METADATA
Metadata key for the token value stored in a Tribuo Example.
-
MASK_VALUE
public static final long MASK_VALUE
Mask value.
-
TOKEN_TYPE_VALUE
public static final long TOKEN_TYPE_VALUE
Token type value for the first sentence.
-
-
Constructor Details
-
BERTFeatureExtractor
public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath)
Constructs a BERTFeatureExtractor.
- Parameters:
outputFactory - The output factory to use for building any unknown outputs.
modelPath - The path to BERT in onnx format.
tokenizerPath - The path to a HuggingFace tokenizer JSON file.
-
BERTFeatureExtractor
public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath, BERTFeatureExtractor.OutputPooling pooling, int maxLength, boolean useCUDA)
Constructs a BERTFeatureExtractor.
- Parameters:
outputFactory - The output factory to use for building any unknown outputs.
modelPath - The path to BERT in onnx format.
tokenizerPath - The path to a HuggingFace tokenizer JSON file.
pooling - The pooling type for extracted Examples.
maxLength - The maximum number of wordpieces.
useCUDA - Set to true to enable CUDA.
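As a sketch of the full constructor, the helper below assumes the OutputPooling enum exposes a MEAN constant (check the constants available in your Tribuo release), a Label output type, and hypothetical model and tokenizer paths.

    import java.nio.file.Path;

    import org.tribuo.classification.Label;
    import org.tribuo.classification.LabelFactory;
    import org.tribuo.interop.onnx.extractors.BERTFeatureExtractor;

    public class PooledBertSketch {
        // A sketch of the six-argument constructor: mean pooling, 256 wordpieces max, CPU only.
        // OutputPooling.MEAN is an assumption; substitute the pooling constant you need.
        public static BERTFeatureExtractor<Label> buildExtractor() throws Exception {
            return new BERTFeatureExtractor<>(
                new LabelFactory(),
                Path.of("bert-base-uncased.onnx"),   // hypothetical ONNX export
                Path.of("tokenizer.json"),           // hypothetical tokenizer config
                BERTFeatureExtractor.OutputPooling.MEAN,
                256,      // maxLength, counted in wordpieces (includes [CLS] and [SEP])
                false);   // useCUDA
        }
    }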
-
-
Method Details
-
postConfig
public void postConfig() throws com.oracle.labs.mlrg.olcut.config.PropertyException
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Throws:
com.oracle.labs.mlrg.olcut.config.PropertyException
-
getProvenance
public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()
-
reconfigureOrtSession
public void reconfigureOrtSession(ai.onnxruntime.OrtSession.SessionOptions options) throws ai.onnxruntime.OrtException
Reconstructs the OrtSession using the supplied options. This allows the use of different computation backends and configurations.
- Parameters:
options - The new session options.
- Throws:
ai.onnxruntime.OrtException - If the native runtime failed to rebuild itself.
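For instance, here is a sketch of switching an already constructed extractor onto different ONNX Runtime settings, reusing the bert instance from the earlier sketches inside a method that declares ai.onnxruntime.OrtException. The thread counts and the CUDA device id are purely illustrative, and addCUDA only works when the GPU build of ONNX Runtime is on the classpath.

    import ai.onnxruntime.OrtSession;

    // Rebuild the session with explicit threading; values here are illustrative.
    OrtSession.SessionOptions options = new OrtSession.SessionOptions();
    options.setIntraOpNumThreads(4);
    options.setInterOpNumThreads(1);
    // options.addCUDA(0);   // uncomment to target GPU 0 with the onnxruntime_gpu dependency
    bert.reconfigureOrtSession(options);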
-
getMaxLength
public int getMaxLength()
Returns the maximum length this BERT will accept.
- Returns:
- The maximum number of tokens (including [CLS] and [SEP], so the maximum is effectively 2 less than this).
-
getVocab
public Set<String> getVocab()
Returns the vocabulary that this BERTFeatureExtractor understands.
-
extractExample
public Example<T> extractExample(List<String> tokens)
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token. The features of the returned example are dense, and come from the [CLS] token.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
- Returns:
- A dense example representing the pooled output from BERT for the input tokens.
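A sketch of pooled embedding extraction, reusing a bert extractor constructed as in the earlier sketches and assuming the dense result is an Example over Label; the token list is illustrative and in practice should come from the same wordpiece tokenizer the model was exported with.

    import java.util.List;

    import org.tribuo.Example;
    import org.tribuo.Feature;
    import org.tribuo.classification.Label;

    // Tokens are illustrative; real input should be wordpiece-tokenized to match the model's vocabulary.
    List<String> tokens = List.of("the", "quick", "brown", "fox");
    Example<Label> pooled = bert.extractExample(tokens);

    // The example is dense: one feature per embedding dimension, drawn from the [CLS] token.
    for (Feature f : pooled) {
        System.out.println(f.getName() + " = " + f.getValue());
    }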
-
extractExample
public Example<T> extractExample(List<String> tokens, T output)
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token. The features of the returned example are dense, and are controlled by the output pooling field.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
output - The ground truth output for this example.
- Returns:
- A dense example representing the pooled output from BERT for the input tokens.
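To illustrate the labelled overload, the following sketch pairs pre-tokenized documents with their ground truth Labels so the resulting examples could later be added to a dataset or passed to a trainer; the documents and label names are made up, and bert is the extractor from the earlier sketches.

    import java.util.ArrayList;
    import java.util.List;

    import org.tribuo.Example;
    import org.tribuo.classification.Label;

    // Hypothetical pre-tokenized documents and their ground truth labels.
    List<List<String>> documents = List.of(
        List.of("great", "film", "loved", "it"),
        List.of("terrible", "plot", "awful", "acting"));
    List<Label> labels = List.of(new Label("positive"), new Label("negative"));

    List<Example<Label>> examples = new ArrayList<>();
    for (int i = 0; i < documents.size(); i++) {
        // Each example carries pooled BERT features plus the supplied ground truth output.
        examples.add(bert.extractExample(documents.get(i), labels.get(i)));
    }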
-
extractSequenceExample
public SequenceExample<T> extractSequenceExample(List<String> tokens, boolean stripSentenceMarkers)
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token. The features of each example are dense. If stripSentenceMarkers is true then the [CLS] and [SEP] tokens are removed before example generation. If it's false then they are left in with the appropriate unknown output set.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
stripSentenceMarkers - Remove the [CLS] and [SEP] tokens from the returned example.
- Returns:
- A dense sequence example representing the token level output from BERT.
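As a sketch of token-level extraction, this reuses the bert extractor from the earlier snippets to pull one dense embedding per wordpiece, stripping [CLS] and [SEP]; the token list is again illustrative.

    import java.util.List;

    import org.tribuo.classification.Label;
    import org.tribuo.sequence.SequenceExample;

    // One dense feature vector per remaining token; [CLS] and [SEP] are stripped here.
    List<String> tokens = List.of("jane", "visited", "london");
    SequenceExample<Label> seq = bert.extractSequenceExample(tokens, true);
    System.out.println("Tokens embedded: " + seq.size());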
-
extractSequenceExample
public SequenceExample<T> extractSequenceExample(List<String> tokens, List<T> output, boolean stripSentenceMarkers)
Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token. The features of each example are dense. The output list must be the same length as the number of tokens. If stripSentenceMarkers is true then the [CLS] and [SEP] tokens are removed before example generation. If it's false then they are left in with the appropriate unknown output set.
Throws IllegalArgumentException if the list is longer than getMaxLength(). Throws IllegalStateException if the BERT model failed to produce an output.
- Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
output - The ground truth output for this example.
stripSentenceMarkers - Remove the [CLS] and [SEP] tokens from the returned example.
- Returns:
- A dense sequence example representing the token level output from BERT.
-
close
public void close() throws ai.onnxruntime.OrtException
- Specified by:
close in interface AutoCloseable
- Throws:
ai.onnxruntime.OrtException
-
extract
public Example<T> extract(T output, String data)
Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for the [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
- Specified by:
extract in interface TextFeatureExtractor<T extends Output<T>>
- Parameters:
output - The output object.
data - The input text.
- Returns:
- An example containing BERT embedding features and the requested output.
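A sketch of the TextFeatureExtractor entry point, handing raw text straight to the extractor; it reuses the bert instance from the earlier snippets, the label names and sentences are invented, and the factory's getUnknownOutput() stands in for the ground truth when none is known.

    import org.tribuo.Example;
    import org.tribuo.classification.Label;
    import org.tribuo.classification.LabelFactory;

    LabelFactory labelFactory = new LabelFactory();

    // Raw text in, dense BERT features out; tokenization and truncation happen inside extract.
    Example<Label> labelled = bert.extract(new Label("positive"), "A thoroughly enjoyable read.");

    // When no ground truth is available, the factory's unknown output can be used instead.
    Example<Label> unlabelled = bert.extract(labelFactory.getUnknownOutput(), "Sentiment unknown for this one.");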
-
process
public List<Feature> process(String tag, String data)
Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for the [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
- Specified by:
process in interface TextPipeline
- Parameters:
tag - A tag to prefix all the generated feature names with.
data - The input text.
- Returns:
- The BERT features for the supplied data.
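For the TextPipeline view, a sketch that tags the generated feature names, again reusing the bert extractor from the earlier snippets; the assumption that process returns a list of org.tribuo.Feature objects follows the TextPipeline contract, and the tag string is arbitrary.

    import java.util.List;

    import org.tribuo.Feature;

    // Each returned feature name is prefixed with the supplied tag, so several
    // pipelines (e.g. BERT plus a bag-of-words pipeline) can contribute to one example.
    List<Feature> bertFeatures = bert.process("bert", "The report was filed on time.");
    System.out.println("Generated " + bertFeatures.size() + " tagged features.");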
-
main
public static void main(String[] args) throws IOException, ai.onnxruntime.OrtException
Test harness for running a BERT model and inspecting the output.
- Parameters:
args - The CLI arguments.
- Throws:
IOException - If the files couldn't be read or written to.
ai.onnxruntime.OrtException - If the BERT model failed to load, or threw an exception during computation.
-