Commit 782e39e8 authored by Hongkun Yu, committed by A. Unique TensorFlower

Release the fast tokenizer BERT wrapper on GitHub (see "Fast WordPiece Tokenization", https://arxiv.org/abs/2012.15524).

PiperOrigin-RevId: 416671104
Parent c8bb9aa5
@@ -46,6 +46,7 @@ from official.nlp.modeling.layers.spectral_normalization import *
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.text_layers import BertPackInputs
from official.nlp.modeling.layers.text_layers import BertTokenizer
from official.nlp.modeling.layers.text_layers import FastWordpieceBertTokenizer
from official.nlp.modeling.layers.text_layers import SentencepieceTokenizer
from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
from official.nlp.modeling.layers.transformer import *
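# Illustrative import of the newly exported layer (a sketch; it assumes the
# standard tf-models package layout in which this __init__.py re-exports layers):
#   from official.nlp.modeling import layers
#   tokenizer = layers.FastWordpieceBertTokenizer(
#       vocab_file="/path/to/vocab.txt", lower_case=True)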
@@ -13,18 +13,22 @@
# limitations under the License.
"""Keras Layers for BERT-specific preprocessing."""
# pylint: disable=g-import-not-at-top
from typing import Any, Dict, List, Optional, Union
from absl import logging
import tensorflow as tf
try:
import tensorflow_text as text
from tensorflow_text.python.ops import bert_tokenizer
except ImportError:
text = None
bert_tokenizer = None
except tf.errors.NotFoundError as e:
logging.warn("Encountered error when importing tensorflow_text: %s", e)
text = None
bert_tokenizer = None
def _check_if_tf_text_installed():
@@ -587,3 +591,139 @@ class BertPackInputs(tf.keras.layers.Layer):
return dict(input_word_ids=_reshape(input_word_ids),
input_mask=_reshape(input_mask),
input_type_ids=_reshape(input_type_ids))
class FastWordpieceBertTokenizer(tf.keras.layers.Layer):
"""A bert tokenizer keras layer using text.FastWordpieceTokenizer.
See details: "Fast WordPiece Tokenization" (https://arxiv.org/abs/2012.15524)
"""
def __init__(self,
*,
vocab_file: str,
lower_case: bool,
tokenize_with_offsets: bool = False,
**kwargs):
"""Initializes a FastWordpieceBertTokenizer layer.
Args:
vocab_file: A Python string with the path of the vocabulary file. This is
a text file with newline-separated wordpiece tokens. This layer loads
a list of tokens from it to create text.FastWordpieceTokenizer.
lower_case: A Python boolean forwarded to text.BasicTokenizer. If true,
input text is converted to lower case (where applicable) before
tokenization. This must be set to match the way in which the vocab_file
was created.
tokenize_with_offsets: A Python boolean. If true, this layer calls
FastWordpieceTokenizer.tokenize_with_offsets() instead of plain
.tokenize() and outputs a triple of (tokens, start_offsets,
limit_offsets) instead of just tokens.
**kwargs: standard arguments to Layer().
"""
super().__init__(**kwargs)
logging.info("Initialize a FastWordpieceBertTokenizer.")
self.tokenize_with_offsets = tokenize_with_offsets
self._basic_tokenizer = bert_tokenizer.BasicTokenizer(lower_case=lower_case)
# Read the vocab file into a list of tokens to create `fast_wp_tokenizer`.
self._vocab = [line.rstrip() for line in tf.io.gfile.GFile(vocab_file)]
self._fast_wp_tokenizer = text.FastWordpieceTokenizer(
vocab=self._vocab, token_out_type=tf.int32, no_pretokenization=True)
self._special_tokens_dict = self._create_special_tokens_dict()
@property
def vocab_size(self):
return len(self._vocab)
def get_config(self):
# Skip in tf.saved_model.save(); fail if called directly.
# We cannot just put the original, user-supplied vocab file name into
# the config, because the path has to change as the SavedModel is copied
# around.
raise NotImplementedError("Not implemented yet.")
def get_special_tokens_dict(self):
"""Returns dict of token ids, keyed by standard names for their purpose.
Returns:
A dict from Python strings to Python integers. Each key is a standard
name for a special token describing its use. (For example, "padding_id"
is what BERT traditionally calls "[PAD]" but others may call "<pad>".)
The corresponding value is the integer token id. If a special token
is not found, its entry is omitted from the dict.
The supported keys and tokens are:
* start_of_sequence_id: looked up from "[CLS]"
* end_of_segment_id: looked up from "[SEP]"
* padding_id: looked up from "[PAD]"
* mask_id: looked up from "[MASK]"
* vocab_size: one past the largest token id used
"""
return self._special_tokens_dict
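# For example (as exercised by test_special_tokens_complete in the test file
# below), a vocab of ["foo", "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "xy"]
# yields:
#   {"padding_id": 1, "start_of_sequence_id": 3, "end_of_segment_id": 4,
#    "mask_id": 5, "vocab_size": 7}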
def _create_special_tokens_dict(self):
"""Creates dict of token ids, keyed by standard names for their purpose."""
special_tokens = {"vocab_size": self.vocab_size}
def add_special_token(key, token):
try:
token_id = self._vocab.index(token)
special_tokens[key] = token_id
except ValueError:
# As in nlp.modeling.layers.BertTokenizer, if a special token
# is not found, its entry is omitted from the dict.
logging.warning("Could not find %s as token \"%s\" in vocab file", key,
token)
add_special_token("start_of_sequence_id", "[CLS]")
add_special_token("end_of_segment_id", "[SEP]")
add_special_token("padding_id", "[PAD]")
add_special_token("mask_id", "[MASK]")
return special_tokens
def _tokenize_with_offsets(self, text_input: tf.Tensor):
tokens, begin, _ = self._basic_tokenizer.tokenize_with_offsets(text_input)
wordpieces, wp_begin, wp_end = (
self._fast_wp_tokenizer.tokenize_with_offsets(tokens))
begin_expanded = tf.expand_dims(begin, axis=2)
final_begin = begin_expanded + wp_begin
final_end = begin_expanded + wp_end
return wordpieces, final_begin, final_end
def _tokenize(self, text_input: tf.Tensor):
tokens = self._basic_tokenizer.tokenize(text_input)
return self._fast_wp_tokenizer.tokenize(tokens)
def call(self, inputs: tf.Tensor):
"""Calls text.BertTokenizer on inputs.
Args:
inputs: A string Tensor of shape [batch_size].
Returns:
One or three of RaggedTensors if tokenize_with_offsets is False or True,
respectively. These are
tokens: A RaggedTensor of shape [batch_size, (words), (pieces_per_word)]
and type int32. tokens[i,j,k] contains the k-th wordpiece of the
j-th word in the i-th input.
start_offsets, limit_offsets: If tokenize_with_offsets is True,
RaggedTensors of type int64 with the same indices as tokens.
Element [i,j,k] contains the byte offset at the start, or past the
end, resp., for the k-th wordpiece of the j-th word in the i-th input.
"""
# Prepare to reshape the result to work around broken shape inference.
batch_size = tf.shape(inputs)[0]
def _reshape(rt):
values = rt.values
row_splits = rt.row_splits
row_splits = tf.reshape(row_splits, [batch_size + 1])
return tf.RaggedTensor.from_row_splits(values, row_splits)
if self.tokenize_with_offsets:
tokens, start_offsets, limit_offsets = self._tokenize_with_offsets(inputs)
return _reshape(tokens), _reshape(start_offsets), _reshape(limit_offsets)
else:
tokens = self._tokenize(inputs)
return _reshape(tokens)
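# Illustrative sketch of the offsets mode (placeholder vocab path; the expected
# outputs follow test_uncased in the test file below):
#
#   tokenizer = FastWordpieceBertTokenizer(
#       vocab_file="/path/to/vocab.txt", lower_case=True,
#       tokenize_with_offsets=True)
#   tokens, starts, limits = tokenizer(tf.constant(["abc def"]))
#   # starts[i, j, k] and limits[i, j, k] are byte offsets into the i-th input
#   # string, delimiting the k-th wordpiece of its j-th word.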
@@ -442,5 +442,109 @@ class BertPackInputsTest(tf.test.TestCase):
[1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
# This test covers the in-process behavior of the FastWordpieceBertTokenizer layer.
class FastWordPieceBertTokenizerTest(tf.test.TestCase):
def _make_vocab_file(self, vocab, filename="vocab.txt"):
path = os.path.join(
tempfile.mkdtemp(dir=self.get_temp_dir()), # New subdir each time.
filename)
with tf.io.gfile.GFile(path, "w") as f:
f.write("\n".join(vocab + [""]))
return path
def test_uncased(self):
vocab_file = self._make_vocab_file(
["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
bert_tokenize = text_layers.FastWordpieceBertTokenizer(
vocab_file=vocab_file, lower_case=True)
inputs = tf.constant(["abc def", "ABC DEF d"])
token_ids = bert_tokenize(inputs)
self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
[[6], [4, 5], [4]]]))
bert_tokenize.tokenize_with_offsets = True
token_ids_2, start_offsets, limit_offsets = bert_tokenize(inputs)
self.assertAllEqual(token_ids, token_ids_2)
self.assertAllEqual(start_offsets, tf.ragged.constant([[[0], [4, 5]],
[[0], [4, 5], [8]]]))
self.assertAllEqual(limit_offsets, tf.ragged.constant([[[3], [5, 7]],
[[3], [5, 7], [9]]]))
self.assertEqual(bert_tokenize.vocab_size, 8)
# Repeat the above and test that case matters with lower_case=False.
def test_cased(self):
vocab_file = self._make_vocab_file(
["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "ABC"])
bert_tokenize = text_layers.FastWordpieceBertTokenizer(
vocab_file=vocab_file, lower_case=False, tokenize_with_offsets=True)
inputs = tf.constant(["abc def", "ABC DEF"])
token_ids, start_offsets, limit_offsets = bert_tokenize(inputs)
self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
[[7], [1]]]))
self.assertAllEqual(start_offsets, tf.ragged.constant([[[0], [4, 5]],
[[0], [4]]]))
self.assertAllEqual(limit_offsets, tf.ragged.constant([[[3], [5, 7]],
[[3], [7]]]))
def test_special_tokens_complete(self):
vocab_file = self._make_vocab_file(
["foo", "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "xy"])
bert_tokenize = text_layers.FastWordpieceBertTokenizer(
vocab_file=vocab_file, lower_case=True)
self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
dict(padding_id=1,
start_of_sequence_id=3,
end_of_segment_id=4,
mask_id=5,
vocab_size=7))
def test_special_tokens_partial(self):
# [UNK] token is required by fast wordpiece tokenizer.
vocab_file = self._make_vocab_file(
["[PAD]", "[CLS]", "[SEP]", "[UNK]"])
bert_tokenize = text_layers.FastWordpieceBertTokenizer(
vocab_file=vocab_file, lower_case=True)
self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
dict(padding_id=0,
start_of_sequence_id=1,
end_of_segment_id=2,
vocab_size=4))  # No mask_id.
def test_special_tokens_in_estimator(self):
"""Tests getting special tokens without an Eager init context."""
vocab_file = self._make_vocab_file(
["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
def input_fn():
with tf.init_scope():
self.assertFalse(tf.executing_eagerly())
# Build a preprocessing Model.
sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
bert_tokenizer = text_layers.FastWordpieceBertTokenizer(
vocab_file=vocab_file, lower_case=True)
special_tokens_dict = bert_tokenizer.get_special_tokens_dict()
for k, v in special_tokens_dict.items():
self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
tokens = bert_tokenizer(sentences)
packed_inputs = text_layers.BertPackInputs(
4, special_tokens_dict=special_tokens_dict)(tokens)
preprocessing = tf.keras.Model(sentences, packed_inputs)
# Map the dataset.
ds = tf.data.Dataset.from_tensors(
(tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
ds = ds.map(lambda features, labels: (preprocessing(features), labels))
return ds
def model_fn(features, labels, mode):
del labels # Unused.
return tf.estimator.EstimatorSpec(mode=mode,
predictions=features["input_word_ids"])
estimator = tf.estimator.Estimator(model_fn=model_fn)
outputs = list(estimator.predict(input_fn))
self.assertAllEqual(outputs, np.array([[2, 6, 3, 0],
[2, 4, 5, 3]]))
if __name__ == "__main__":
tf.test.main()