diff --git a/official/projects/lra/exponential_moving_average.py b/official/projects/lra/exponential_moving_average.py
index 0ca7db0befd43078d20d47c4f7e8ed58523e0fbc..221405d319864d07d3ce4239488296a9788a653d 100644
--- a/official/projects/lra/exponential_moving_average.py
+++ b/official/projects/lra/exponential_moving_average.py
@@ -31,8 +31,9 @@ class MultiHeadEMA(tf.keras.layers.Layer):
       ndim=2,
       bidirectional=False,
       truncation=None,
+      **kwargs
   ):
-    super().__init__()
+    super().__init__(**kwargs)
 
     self.embed_dim = embed_dim
     self.ndim = ndim
diff --git a/official/projects/lra/mega_encoder.py b/official/projects/lra/mega_encoder.py
index 16d386e8a6e885d9ada83d7901663f6ee2ed7557..619f5ac2671dd56ca37eb717594f7b1aebce3bf0 100644
--- a/official/projects/lra/mega_encoder.py
+++ b/official/projects/lra/mega_encoder.py
@@ -28,6 +28,7 @@ _Initializer = Union[str, tf.keras.initializers.Initializer]
 _approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
 
 
+@tf.keras.utils.register_keras_serializable(package='Text')
 class MegaEncoder(tf.keras.layers.Layer):
   """MegaEncoder.
 
diff --git a/official/projects/lra/mega_encoder_test.py b/official/projects/lra/mega_encoder_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..907256e69764e64ee5e58308bd806655549d4d4b
--- /dev/null
+++ b/official/projects/lra/mega_encoder_test.py
@@ -0,0 +1,43 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for official.nlp.projects.lra.mega_encoder.""" + +import numpy as np +import tensorflow as tf + +from official.projects.lra import mega_encoder + + +class MegaEncoderTest(tf.test.TestCase): + + def test_encoder(self): + sequence_length = 1024 + batch_size = 2 + vocab_size = 1024 + network = mega_encoder.MegaEncoder( + num_layers=1, vocab_size=1024, max_sequence_length=4096) + word_id_data = np.random.randint( + vocab_size, size=(batch_size, sequence_length)) + mask_data = np.random.randint(2, size=(batch_size, sequence_length)) + type_id_data = np.random.randint(2, size=(batch_size, sequence_length)) + outputs = network({"input_word_ids": word_id_data, + "input_mask": mask_data, + "input_type_ids": type_id_data}) + self.assertEqual(outputs["sequence_output"].shape, + (batch_size, sequence_length, 128)) + + +if __name__ == "__main__": + tf.test.main() \ No newline at end of file diff --git a/official/projects/lra/moving_average_gated_attention.py b/official/projects/lra/moving_average_gated_attention.py index d7fbc43b9845396230c5098b6245c7e4cb784cbd..b163e6aa38dd1e32705abe54b56847b83a5f9a32 100644 --- a/official/projects/lra/moving_average_gated_attention.py +++ b/official/projects/lra/moving_average_gated_attention.py @@ -48,8 +48,6 @@ class RelativePositionBias(tf.keras.layers.Layer): def call(self, seq_len): if seq_len is None: seq_len = self.max_positions - #import pdb - #pdb.set_trace() seq_len = tf.get_static_value(seq_len) # seq_len * 2 -1 b = self.rel_pos_bias[(self.max_positions - seq_len):(self.max_positions + @@ -199,7 +197,8 @@ class MovingAverageGatedAttention(tf.keras.layers.Layer): super().build(input_shape) def get_config(self): - config = { + base_config = super().get_config() + base_config.update({ "embed_dim": self.embed_dim, "zdim": @@ -226,11 +225,10 @@ class MovingAverageGatedAttention(tf.keras.layers.Layer): self._attention_axes, "return_attention_scores": self.return_attention_scores - } - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) + }) + return base_config - def softmax_attention(self, q, k): + def _softmax_attention(self, q, k): slen = k.shape[1] # C x C if slen is None: @@ -246,19 +244,17 @@ class MovingAverageGatedAttention(tf.keras.layers.Layer): return attn_weights def call(self, inputs: Any) -> Any: - """MEGA encoder block call. - - Args: - inputs: a single tensor or a list of tensors. `input tensor` + """ + MEGA encoder block call. + Args: inputs: a single tensor or a list of tensors. `input tensor` as the single sequence of embeddings. [`input tensor`, `attention mask`] to have the additional attention mask. [`query tensor`, `key value tensor`, `attention mask`] to have separate input streams for the query, and key/value to the multi-head attention. - Returns: - An output tensor with the same dimensions as input/query tensor. - """ + An output tensor with the same dimensions as input/query tensor. + """ if isinstance(inputs, (list, tuple)): if len(inputs) == 2: (input_tensor, attention_mask) = inputs @@ -310,7 +306,7 @@ class MovingAverageGatedAttention(tf.keras.layers.Layer): # L x B x E -> B x L x E v = tf.transpose(v, perm=(1, 0, 2)) - attn_weights = self.softmax_attention(q, k) + attn_weights = self._softmax_attention(q, k) v = self.hidden_dropout(v) kernel = tf.squeeze(self.attention_dropout(attn_weights)) # B x K x C x E -> B x L x E -> L x B x E