Type Parameters:
T - The output type.

public class BERTFeatureExtractor<T extends Output<T>> extends Object implements AutoCloseable, TextFeatureExtractor<T>, TextPipeline

Assumes that the BERT is an ONNX model generated by HuggingFace Transformers and exported using their export tool. The tokenizer is expected to be a HuggingFace Transformers tokenizer config json file.
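As a usage sketch only (the package locations, file names, and the Label output domain below are assumptions based on this page, not part of the documented API surface), the extractor can be built from an exported ONNX BERT plus its tokenizer config and used through the TextPipeline and TextFeatureExtractor interfaces:

```java
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.interop.onnx.extractors.BERTFeatureExtractor;

public class BERTExtractionSketch {
    public static void main(String[] args) throws Exception {
        Path modelPath = Paths.get("bert-base-uncased.onnx"); // BERT exported via the HuggingFace export tool
        Path tokenizerPath = Paths.get("tokenizer.json");     // HuggingFace tokenizer config json

        // Three-argument constructor; the extractor is AutoCloseable, so try-with-resources releases the ONNX session.
        try (BERTFeatureExtractor<Label> extractor =
                 new BERTFeatureExtractor<>(new LabelFactory(), modelPath, tokenizerPath)) {

            // TextPipeline usage: the tag prefixes the generated feature names.
            List<Feature> features = extractor.process("bert", "Tribuo is a machine learning library.");

            // TextFeatureExtractor usage: wraps the features into an Example around the supplied output.
            Example<Label> example = extractor.extract(
                new LabelFactory().getUnknownOutput(), "Tribuo is a machine learning library.");

            System.out.println("features=" + features.size() + ", example size=" + example.size());
        }
    }
}
```

The same pattern applies to any output domain; only the OutputFactory supplied to the constructor changes.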
Modifier and Type | Class and Description
---|---|
static class | BERTFeatureExtractor.BERTFeatureExtractorOptions: CLI options for running BERT.
static class | BERTFeatureExtractor.OutputPooling: The type of output pooling to perform.
Modifier and Type | Field and Description
---|---|
static String | ATTENTION_MASK
static String | CLASSIFICATION_TOKEN
static String | CLS_OUTPUT
static String | INPUT_IDS
static long | MASK_VALUE
static String | SEPARATOR_TOKEN
static String | TOKEN_METADATA
static String | TOKEN_OUTPUT
static String | TOKEN_TYPE_IDS
static long | TOKEN_TYPE_VALUE
static String | UNKNOWN_TOKEN
Constructor and Description
---|
BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath): Constructs a BERTFeatureExtractor.
BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath, BERTFeatureExtractor.OutputPooling pooling, int maxLength, boolean useCUDA): Constructs a BERTFeatureExtractor.
Modifier and Type | Method and Description
---|---|
void | close()
Example<T> | extract(T output, String data): Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
Example<T> | extractExample(List<String> tokens): Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
Example<T> | extractExample(List<String> tokens, T output): Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
SequenceExample<T> | extractSequenceExample(List<String> tokens, boolean stripSentenceMarkers): Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
SequenceExample<T> | extractSequenceExample(List<String> tokens, List<T> output, boolean stripSentenceMarkers): Passes the tokens through BERT, replacing any unknown tokens with the [UNK] token.
int | getMaxLength(): Returns the maximum length this BERT will accept.
com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance | getProvenance()
Set<String> | getVocab(): Returns the vocabulary that this BERTFeatureExtractor understands.
static void | main(String[] args): Test harness for running a BERT model and inspecting the output.
void | postConfig()
List<Feature> | process(String tag, String data): Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).
void | reconfigureOrtSession(ai.onnxruntime.OrtSession.SessionOptions options): Reconstructs the OrtSession using the supplied options.
public static final String INPUT_IDS
public static final String ATTENTION_MASK
public static final String TOKEN_TYPE_IDS
public static final String TOKEN_OUTPUT
public static final String CLS_OUTPUT
public static final String CLASSIFICATION_TOKEN
public static final String SEPARATOR_TOKEN
public static final String UNKNOWN_TOKEN
public static final String TOKEN_METADATA
public static final long MASK_VALUE
public static final long TOKEN_TYPE_VALUE
public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath)

Parameters:
outputFactory - The output factory to use for building any unknown outputs.
modelPath - The path to BERT in onnx format.
tokenizerPath - The path to a HuggingFace tokenizer json file.

public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path modelPath, Path tokenizerPath, BERTFeatureExtractor.OutputPooling pooling, int maxLength, boolean useCUDA)

Parameters:
outputFactory - The output factory to use for building any unknown outputs.
modelPath - The path to BERT in onnx format.
tokenizerPath - The path to a HuggingFace tokenizer json file.
pooling - The pooling type for extracted Examples.
maxLength - The maximum number of wordpieces.
useCUDA - Set to true to enable CUDA.
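A sketch of the six-argument form with illustrative values (the paths, length, and CUDA flag are placeholders, and the OutputPooling constant named below is an assumption; check the BERTFeatureExtractor.OutputPooling enum for its actual values):

```java
// All values illustrative; OutputPooling.CLS is an assumed constant name, see the enum for the real ones.
BERTFeatureExtractor<Label> extractor = new BERTFeatureExtractor<>(
    new LabelFactory(),                       // outputFactory - builds any unknown outputs
    Paths.get("bert-base-uncased.onnx"),      // modelPath - BERT in onnx format
    Paths.get("tokenizer.json"),              // tokenizerPath - HuggingFace tokenizer json file
    BERTFeatureExtractor.OutputPooling.CLS,   // pooling - pooling type for extracted Examples
    256,                                      // maxLength - maximum number of wordpieces
    false                                     // useCUDA - set to true to enable CUDA
);
```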
public void postConfig() throws com.oracle.labs.mlrg.olcut.config.PropertyException

Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable

Throws:
com.oracle.labs.mlrg.olcut.config.PropertyException
public com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance getProvenance()

Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance>
public void reconfigureOrtSession(ai.onnxruntime.OrtSession.SessionOptions options) throws ai.onnxruntime.OrtException

Parameters:
options - The new session options.

Throws:
ai.onnxruntime.OrtException - If the native runtime failed to rebuild itself.
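For illustration (a sketch against ONNX Runtime's SessionOptions; the thread count is arbitrary and `extractor` is assumed to be an already constructed BERTFeatureExtractor):

```java
// Rebuild the native session with new options, e.g. a different intra-op thread count.
// Both calls below can throw ai.onnxruntime.OrtException.
ai.onnxruntime.OrtSession.SessionOptions options = new ai.onnxruntime.OrtSession.SessionOptions();
options.setIntraOpNumThreads(4);          // any valid SessionOptions configuration can be applied here
extractor.reconfigureOrtSession(options); // replaces the existing OrtSession with one built from these options
```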
public int getMaxLength()

public Set<String> getVocab()
public Example<T> extractExample(List<String> tokens)

The features of the returned example are dense, and come from the [CLS] token.
Throws IllegalArgumentException if the list is longer than getMaxLength().
Throws IllegalStateException if the BERT model failed to produce an output.

Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
public Example<T> extractExample(List<String> tokens, T output)

The features of the returned example are dense, and are controlled by the output pooling field.
Throws IllegalArgumentException if the list is longer than getMaxLength().
Throws IllegalStateException if the BERT model failed to produce an output.

Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
output - The ground truth output for this example.
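A sketch of the two overloads (the token list and the Label are illustrative, and `extractor` is assumed to be a BERTFeatureExtractor<Label>; real input should come from the wordpiece tokenizer this BERT expects):

```java
List<String> tokens = Arrays.asList("tribuo", "is", "a", "machine", "learning", "library");

// No ground truth supplied: the dense features come from the [CLS] token.
Example<Label> unlabelled = extractor.extractExample(tokens);

// Ground truth supplied: the dense features follow the configured output pooling.
Example<Label> labelled = extractor.extractExample(tokens, new Label("positive"));
```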
public SequenceExample<T> extractSequenceExample(List<String> tokens, boolean stripSentenceMarkers)

The features of each example are dense.
If stripSentenceMarkers is true then the [CLS] and [SEP] tokens are removed before example generation. If it's false then they are left in with the appropriate unknown output set.
Throws IllegalArgumentException if the list is longer than getMaxLength().
Throws IllegalStateException if the BERT model failed to produce an output.

Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
stripSentenceMarkers - Remove the [CLS] and [SEP] tokens from the returned example.
public SequenceExample<T> extractSequenceExample(List<String> tokens, List<T> output, boolean stripSentenceMarkers)

The features of each example are dense. The output list must be the same length as the number of tokens.
If stripSentenceMarkers is true then the [CLS] and [SEP] tokens are removed before example generation. If it's false then they are left in with the appropriate unknown output set.
Throws IllegalArgumentException if the list is longer than getMaxLength().
Throws IllegalStateException if the BERT model failed to produce an output.

Parameters:
tokens - The input tokens. Should be tokenized using the Tokenizer this BERT expects.
output - The ground truth output for this example.
stripSentenceMarkers - Remove the [CLS] and [SEP] tokens from the returned example.
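A sketch of the sequence variants (the tokens and per-token labels are illustrative; `extractor` is assumed to be a BERTFeatureExtractor<Label>, and the output list length matches the token list as required above):

```java
List<String> tokens = Arrays.asList("[CLS]", "john", "lives", "in", "london", "[SEP]");
List<Label> labels = Arrays.asList(new Label("O"), new Label("PER"), new Label("O"),
                                   new Label("O"), new Label("LOC"), new Label("O"));

// Unlabelled variant: stripSentenceMarkers=true removes [CLS] and [SEP] before example generation.
SequenceExample<Label> unlabelled = extractor.extractSequenceExample(tokens, true);

// Labelled variant: one output per token; with stripSentenceMarkers=false the markers are left in,
// carrying the appropriate unknown output as described above.
SequenceExample<Label> labelled = extractor.extractSequenceExample(tokens, labels, false);
```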
public void close() throws ai.onnxruntime.OrtException

Specified by:
close in interface AutoCloseable

Throws:
ai.onnxruntime.OrtException
public Example<T> extract(T output, String data)

Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).

Specified by:
extract in interface TextFeatureExtractor<T extends Output<T>>

Parameters:
output - The output object.
data - The input text.
public List<Feature> process(String tag, String data)

Tokenizes the input using the loaded tokenizer, truncates the token list if it's longer than maxLength - 2 (to account for [CLS] and [SEP] tokens), and then passes the token list to extractExample(java.util.List<java.lang.String>).

Specified by:
process in interface TextPipeline

Parameters:
tag - A tag to prefix all the generated feature names with.
data - The input text.
public static void main(String[] args) throws IOException, ai.onnxruntime.OrtException

Parameters:
args - The CLI arguments.

Throws:
IOException - If the files couldn't be read or written to.
ai.onnxruntime.OrtException - If the BERT model failed to load, or threw an exception during computation.