diff --git a/official/nlp/modeling/layers/__init__.py b/official/nlp/modeling/layers/__init__.py
index d9b43cb2f4cbeadbaf3d27bfec0bb7d5daf5c4ae..f8f475d40a50d8f05d49e49ee24a6855c1ee13a7 100644
--- a/official/nlp/modeling/layers/__init__.py
+++ b/official/nlp/modeling/layers/__init__.py
@@ -46,6 +46,7 @@ from official.nlp.modeling.layers.spectral_normalization import *
 from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
 from official.nlp.modeling.layers.text_layers import BertPackInputs
 from official.nlp.modeling.layers.text_layers import BertTokenizer
+from official.nlp.modeling.layers.text_layers import FastWordpieceBertTokenizer
 from official.nlp.modeling.layers.text_layers import SentencepieceTokenizer
 from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
 from official.nlp.modeling.layers.transformer import *
diff --git a/official/nlp/modeling/layers/text_layers.py b/official/nlp/modeling/layers/text_layers.py
index 9ca51ab325267a69b2c2bec72d2372e3a9a16ef3..299901d2df7504ef207f4adcf3149c9cac700350 100644
--- a/official/nlp/modeling/layers/text_layers.py
+++ b/official/nlp/modeling/layers/text_layers.py
@@ -13,18 +13,22 @@
 # limitations under the License.
 """Keras Layers for BERT-specific preprocessing."""
 
+# pylint: disable=g-import-not-at-top
 from typing import Any, Dict, List, Optional, Union
 
 from absl import logging
 import tensorflow as tf
 
 try:
-  import tensorflow_text as text  # pylint: disable=g-import-not-at-top
+  import tensorflow_text as text
+  from tensorflow_text.python.ops import bert_tokenizer
 except ImportError:
   text = None
+  bert_tokenizer = None
 except tf.errors.NotFoundError as e:
   logging.warn("Encountered error when importing tensorflow_text: %s", e)
   text = None
+  bert_tokenizer = None
 
 
 def _check_if_tf_text_installed():
@@ -587,3 +591,139 @@ class BertPackInputs(tf.keras.layers.Layer):
     return dict(input_word_ids=_reshape(input_word_ids),
                 input_mask=_reshape(input_mask),
                 input_type_ids=_reshape(input_type_ids))
+
+
+class FastWordpieceBertTokenizer(tf.keras.layers.Layer):
+  """A BERT tokenizer Keras layer using text.FastWordpieceTokenizer.
+
+  See details: "Fast WordPiece Tokenization" (https://arxiv.org/abs/2012.15524)
+  """
+
+  def __init__(self,
+               *,
+               vocab_file: str,
+               lower_case: bool,
+               tokenize_with_offsets: bool = False,
+               **kwargs):
+    """Initializes a FastWordpieceBertTokenizer layer.
+
+    Args:
+      vocab_file: A Python string with the path of the vocabulary file. This is
+        a text file with newline-separated wordpiece tokens. This layer loads
+        a list of tokens from it to create text.FastWordpieceTokenizer.
+      lower_case: A Python boolean forwarded to text.BasicTokenizer. If true,
+        input text is converted to lower case (where applicable) before
+        tokenization. This must be set to match the way in which the vocab_file
+        was created.
+      tokenize_with_offsets: A Python boolean. If true, this layer calls
+        FastWordpieceTokenizer.tokenize_with_offsets() instead of plain
+        .tokenize() and outputs a triple of (tokens, start_offsets,
+        limit_offsets) instead of just tokens.
+      **kwargs: standard arguments to Layer().
+    """
+    super().__init__(**kwargs)
+    logging.info("Initialize a FastWordpieceBertTokenizer.")
+    self.tokenize_with_offsets = tokenize_with_offsets
+    self._basic_tokenizer = bert_tokenizer.BasicTokenizer(lower_case=lower_case)
+
+    # Read the vocab file into a list of tokens to create `fast_wp_tokenizer`.
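+    # Because the vocab is materialized as a Python list, vocab_size and the
+    # special token ids looked up in _create_special_tokens_dict() are plain
+    # Python ints, even when the layer is built without an eager init context
+    # (see test_special_tokens_in_estimator in text_layers_test.py below).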
+    self._vocab = [line.rstrip() for line in tf.io.gfile.GFile(vocab_file)]
+    self._fast_wp_tokenizer = text.FastWordpieceTokenizer(
+        vocab=self._vocab, token_out_type=tf.int32, no_pretokenization=True)
+    self._special_tokens_dict = self._create_special_tokens_dict()
+
+  @property
+  def vocab_size(self):
+    return len(self._vocab)
+
+  def get_config(self):
+    # Skip in tf.saved_model.save(); fail if called directly.
+    # We cannot just put the original, user-supplied vocab file name into
+    # the config, because the path has to change as the SavedModel is copied
+    # around.
+    raise NotImplementedError("Not implemented yet.")
+
+  def get_special_tokens_dict(self):
+    """Returns dict of token ids, keyed by standard names for their purpose.
+
+    Returns:
+      A dict from Python strings to Python integers. Each key is a standard
+      name for a special token describing its use. (For example, "padding_id"
+      is what BERT traditionally calls "[PAD]" but others may call "<pad>".)
+      The corresponding value is the integer token id. If a special token
+      is not found, its entry is omitted from the dict.
+
+      The supported keys and tokens are:
+        * start_of_sequence_id: looked up from "[CLS]"
+        * end_of_segment_id: looked up from "[SEP]"
+        * padding_id: looked up from "[PAD]"
+        * mask_id: looked up from "[MASK]"
+        * vocab_size: one past the largest token id used
+    """
+    return self._special_tokens_dict
+
+  def _create_special_tokens_dict(self):
+    """Creates dict of token ids, keyed by standard names for their purpose."""
+    special_tokens = {"vocab_size": self.vocab_size}
+
+    def add_special_token(key, token):
+      try:
+        token_id = self._vocab.index(token)
+        special_tokens[key] = token_id
+      except ValueError:
+        # Similar to nlp.modeling.layers.BertTokenizer, if a special token
+        # is not found, its entry is omitted from the dict.
+        logging.warning("Could not find %s as token \"%s\" in vocab file", key,
+                        token)
+
+    add_special_token("start_of_sequence_id", "[CLS]")
+    add_special_token("end_of_segment_id", "[SEP]")
+    add_special_token("padding_id", "[PAD]")
+    add_special_token("mask_id", "[MASK]")
+    return special_tokens
+
+  def _tokenize_with_offsets(self, text_input: tf.Tensor):
+    tokens, begin, _ = self._basic_tokenizer.tokenize_with_offsets(text_input)
+    wordpieces, wp_begin, wp_end = (
+        self._fast_wp_tokenizer.tokenize_with_offsets(tokens))
+    begin_expanded = tf.expand_dims(begin, axis=2)
+    final_begin = begin_expanded + wp_begin
+    final_end = begin_expanded + wp_end
+    return wordpieces, final_begin, final_end
+
+  def _tokenize(self, text_input: tf.Tensor):
+    tokens = self._basic_tokenizer.tokenize(text_input)
+    return self._fast_wp_tokenizer.tokenize(tokens)
+
+  def call(self, inputs: tf.Tensor):
+    """Calls the basic tokenizer and text.FastWordpieceTokenizer on inputs.
+
+    Args:
+      inputs: A string Tensor of shape [batch_size].
+
+    Returns:
+      One or three RaggedTensors if tokenize_with_offsets is False or True,
+      respectively. These are
+      tokens: A RaggedTensor of shape [batch_size, (words), (pieces_per_word)]
+        and type int32. tokens[i,j,k] contains the k-th wordpiece of the
+        j-th word in the i-th input.
+      start_offsets, limit_offsets: If tokenize_with_offsets is True,
+        RaggedTensors of type int64 with the same indices as tokens.
+        Element [i,j,k] contains the byte offset at the start, or past the
+        end, resp., for the k-th wordpiece of the j-th word in the i-th input.
+    """
+    # Prepare to reshape the result to work around broken shape inference.
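+    # _reshape() below rebuilds each RaggedTensor returned by the tokenizers
+    # with row_splits reshaped to [batch_size + 1], which restores the static
+    # shape information that downstream layers such as BertPackInputs expect.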
+    batch_size = tf.shape(inputs)[0]
+
+    def _reshape(rt):
+      values = rt.values
+      row_splits = rt.row_splits
+      row_splits = tf.reshape(row_splits, [batch_size + 1])
+      return tf.RaggedTensor.from_row_splits(values, row_splits)
+
+    if self.tokenize_with_offsets:
+      tokens, start_offsets, limit_offsets = self._tokenize_with_offsets(inputs)
+      return _reshape(tokens), _reshape(start_offsets), _reshape(limit_offsets)
+    else:
+      tokens = self._tokenize(inputs)
+      return _reshape(tokens)
diff --git a/official/nlp/modeling/layers/text_layers_test.py b/official/nlp/modeling/layers/text_layers_test.py
index d25047a7a55d1c732e54b67a226d7de23d1709ff..0608863ca8f8e933005c760ef597e5a8200d666e 100644
--- a/official/nlp/modeling/layers/text_layers_test.py
+++ b/official/nlp/modeling/layers/text_layers_test.py
@@ -442,5 +442,109 @@ class BertPackInputsTest(tf.test.TestCase):
                          [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
 
 
+
+# This test covers the in-process behavior of FastWordpieceBertTokenizer layer.
+class FastWordPieceBertTokenizerTest(tf.test.TestCase):
+
+  def _make_vocab_file(self, vocab, filename="vocab.txt"):
+    path = os.path.join(
+        tempfile.mkdtemp(dir=self.get_temp_dir()),  # New subdir each time.
+        filename)
+    with tf.io.gfile.GFile(path, "w") as f:
+      f.write("\n".join(vocab + [""]))
+    return path
+
+  def test_uncased(self):
+    vocab_file = self._make_vocab_file(
+        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
+    bert_tokenize = text_layers.FastWordpieceBertTokenizer(
+        vocab_file=vocab_file, lower_case=True)
+    inputs = tf.constant(["abc def", "ABC DEF d"])
+    token_ids = bert_tokenize(inputs)
+    self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
+                                                       [[6], [4, 5], [4]]]))
+    bert_tokenize.tokenize_with_offsets = True
+    token_ids_2, start_offsets, limit_offsets = bert_tokenize(inputs)
+    self.assertAllEqual(token_ids, token_ids_2)
+    self.assertAllEqual(start_offsets, tf.ragged.constant([[[0], [4, 5]],
+                                                           [[0], [4, 5], [8]]]))
+    self.assertAllEqual(limit_offsets, tf.ragged.constant([[[3], [5, 7]],
+                                                           [[3], [5, 7], [9]]]))
+    self.assertEqual(bert_tokenize.vocab_size, 8)
+
+  # Repeat the above and test that case matters with lower_case=False.
+  def test_cased(self):
+    vocab_file = self._make_vocab_file(
+        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "ABC"])
+    bert_tokenize = text_layers.FastWordpieceBertTokenizer(
+        vocab_file=vocab_file, lower_case=False, tokenize_with_offsets=True)
+    inputs = tf.constant(["abc def", "ABC DEF"])
+    token_ids, start_offsets, limit_offsets = bert_tokenize(inputs)
+    self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
+                                                       [[7], [1]]]))
+    self.assertAllEqual(start_offsets, tf.ragged.constant([[[0], [4, 5]],
+                                                           [[0], [4]]]))
+    self.assertAllEqual(limit_offsets, tf.ragged.constant([[[3], [5, 7]],
+                                                           [[3], [7]]]))
+
+  def test_special_tokens_complete(self):
+    vocab_file = self._make_vocab_file(
+        ["foo", "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "xy"])
+    bert_tokenize = text_layers.FastWordpieceBertTokenizer(
+        vocab_file=vocab_file, lower_case=True)
+    self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
+                         dict(padding_id=1,
+                              start_of_sequence_id=3,
+                              end_of_segment_id=4,
+                              mask_id=5,
+                              vocab_size=7))
+
+  def test_special_tokens_partial(self):
+    # [UNK] token is required by fast wordpiece tokenizer.
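+    # [MASK] is deliberately left out of this vocab, so the expected special
+    # tokens dict below must omit mask_id while keeping the other entries.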
+    vocab_file = self._make_vocab_file(
+        ["[PAD]", "[CLS]", "[SEP]", "[UNK]"])
+    bert_tokenize = text_layers.FastWordpieceBertTokenizer(
+        vocab_file=vocab_file, lower_case=True)
+    self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
+                         dict(padding_id=0,
+                              start_of_sequence_id=1,
+                              end_of_segment_id=2,
+                              vocab_size=4))  # No mask_id.
+
+  def test_special_tokens_in_estimator(self):
+    """Tests getting special tokens without an Eager init context."""
+    vocab_file = self._make_vocab_file(
+        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
+
+    def input_fn():
+      with tf.init_scope():
+        self.assertFalse(tf.executing_eagerly())
+      # Build a preprocessing Model.
+      sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
+      bert_tokenizer = text_layers.FastWordpieceBertTokenizer(
+          vocab_file=vocab_file, lower_case=True)
+      special_tokens_dict = bert_tokenizer.get_special_tokens_dict()
+      for k, v in special_tokens_dict.items():
+        self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
+      tokens = bert_tokenizer(sentences)
+      packed_inputs = text_layers.BertPackInputs(
+          4, special_tokens_dict=special_tokens_dict)(tokens)
+      preprocessing = tf.keras.Model(sentences, packed_inputs)
+      # Map the dataset.
+      ds = tf.data.Dataset.from_tensors(
+          (tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
+      ds = ds.map(lambda features, labels: (preprocessing(features), labels))
+      return ds
+
+    def model_fn(features, labels, mode):
+      del labels  # Unused.
+      return tf.estimator.EstimatorSpec(mode=mode,
+                                        predictions=features["input_word_ids"])
+
+    estimator = tf.estimator.Estimator(model_fn=model_fn)
+    outputs = list(estimator.predict(input_fn))
+    self.assertAllEqual(outputs, np.array([[2, 6, 3, 0],
+                                           [2, 4, 5, 3]]))
+
+
 if __name__ == "__main__":
   tf.test.main()
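For reference, a minimal usage sketch of the layer added above, wired into a Keras preprocessing model the same way test_special_tokens_in_estimator does. It assumes tensorflow_text is installed; "/tmp/vocab.txt" is a hypothetical stand-in for a newline-separated wordpiece vocabulary containing "[PAD]", "[UNK]", "[CLS]" and "[SEP]", and the sequence length of 128 is an arbitrary example value.

import tensorflow as tf
from official.nlp.modeling import layers  # FastWordpieceBertTokenizer is exported here.

# Hypothetical vocab path; replace with a real wordpiece vocabulary file.
tokenizer = layers.FastWordpieceBertTokenizer(
    vocab_file="/tmp/vocab.txt", lower_case=True)
special_tokens_dict = tokenizer.get_special_tokens_dict()

# Strings in, packed BERT inputs out.
sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
tokens = tokenizer(sentences)  # Ragged [batch, (words), (pieces_per_word)].
packed_inputs = layers.BertPackInputs(
    128, special_tokens_dict=special_tokens_dict)(tokens)
preprocessing = tf.keras.Model(sentences, packed_inputs)

# The model maps a batch of strings to a dict of input_word_ids, input_mask
# and input_type_ids, ready to feed a BERT encoder.
outputs = preprocessing(tf.constant(["abc def", "xy abc"]))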