Commit a108f087 authored by Yufan Zhuang

update for PR comments

Parent 01cf744f
@@ -31,8 +31,9 @@ class MultiHeadEMA(tf.keras.layers.Layer):
       ndim=2,
       bidirectional=False,
       truncation=None,
+      **kwargs
   ):
-    super().__init__()
+    super().__init__(**kwargs)
     self.embed_dim = embed_dim
     self.ndim = ndim
......
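The change above forwards `**kwargs` to `tf.keras.layers.Layer.__init__`, so standard layer arguments reach the base class. A minimal sketch of the pattern, using a made-up stand-in layer rather than the real `MultiHeadEMA`:

```python
import tensorflow as tf


class TinyEMA(tf.keras.layers.Layer):
  """Hypothetical stand-in for MultiHeadEMA, only to illustrate the pattern."""

  def __init__(self, embed_dim, ndim=2, **kwargs):
    # Forwarding **kwargs lets `name`, `dtype`, `trainable`, etc. reach the
    # base Layer, which Keras relies on when rebuilding the layer from config.
    super().__init__(**kwargs)
    self.embed_dim = embed_dim
    self.ndim = ndim


layer = TinyEMA(embed_dim=128, name="ema_gate", dtype="float32")
assert layer.name == "ema_gate"
```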
@@ -28,6 +28,7 @@ _Initializer = Union[str, tf.keras.initializers.Initializer]
 _approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
+@tf.keras.utils.register_keras_serializable(package='Text')
 class MegaEncoder(tf.keras.layers.Layer):
   """MegaEncoder.
......
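The `register_keras_serializable` decorator added above registers the class under the `Text` package so Keras can resolve it by name during deserialization. A hedged, self-contained sketch of what that enables (the `Double` layer is a made-up stand-in, not `MegaEncoder`):

```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='Text')
class Double(tf.keras.layers.Layer):
  """Made-up layer; registered the same way as MegaEncoder above."""

  def call(self, inputs):
    return 2.0 * inputs


# Because the class is registered, deserialize() can resolve it by name
# without passing a custom_objects mapping.
config = tf.keras.layers.serialize(Double(name="doubler"))
restored = tf.keras.layers.deserialize(config)
assert isinstance(restored, Double)
```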
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.lra.mega_encoder."""
import numpy as np
import tensorflow as tf
from official.projects.lra import mega_encoder
class MegaEncoderTest(tf.test.TestCase):
def test_encoder(self):
sequence_length = 1024
batch_size = 2
vocab_size = 1024
network = mega_encoder.MegaEncoder(
num_layers=1, vocab_size=1024, max_sequence_length=4096)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(2, size=(batch_size, sequence_length))
outputs = network({"input_word_ids": word_id_data,
"input_mask": mask_data,
"input_type_ids": type_id_data})
self.assertEqual(outputs["sequence_output"].shape,
(batch_size, sequence_length, 128))
if __name__ == "__main__":
tf.test.main()
\ No newline at end of file
@@ -48,8 +48,6 @@ class RelativePositionBias(tf.keras.layers.Layer):
   def call(self, seq_len):
     if seq_len is None:
       seq_len = self.max_positions
-    #import pdb
-    #pdb.set_trace()
     seq_len = tf.get_static_value(seq_len)
     # seq_len * 2 -1
     b = self.rel_pos_bias[(self.max_positions - seq_len):(self.max_positions +
@@ -199,7 +197,8 @@ class MovingAverageGatedAttention(tf.keras.layers.Layer):
     super().build(input_shape)

   def get_config(self):
-    config = {
+    base_config = super().get_config()
+    base_config.update({
         "embed_dim":
             self.embed_dim,
         "zdim":
@@ -226,11 +225,10 @@ class MovingAverageGatedAttention(tf.keras.layers.Layer):
             self._attention_axes,
         "return_attention_scores":
             self.return_attention_scores
-    }
-    base_config = super().get_config()
-    return dict(list(base_config.items()) + list(config.items()))
+    })
+    return base_config

-  def softmax_attention(self, q, k):
+  def _softmax_attention(self, q, k):
     slen = k.shape[1]
     # C x C
     if slen is None:
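The `get_config` change above switches from merging two dicts to updating `super().get_config()` in place. A small, self-contained sketch of the same pattern and the round trip it supports (the `GatedToy` layer and its fields are illustrative, not the real MEGA layer):

```python
import tensorflow as tf


class GatedToy(tf.keras.layers.Layer):
  """Illustrative layer using the same get_config pattern as in the diff."""

  def __init__(self, zdim=64, hdim=256, **kwargs):
    super().__init__(**kwargs)
    self.zdim = zdim
    self.hdim = hdim

  def get_config(self):
    # Start from the base Layer config and add this layer's own arguments,
    # mirroring the base_config.update({...}) pattern above.
    base_config = super().get_config()
    base_config.update({"zdim": self.zdim, "hdim": self.hdim})
    return base_config


layer = GatedToy(zdim=32, name="mega_block")
clone = GatedToy.from_config(layer.get_config())
assert clone.zdim == 32 and clone.name == "mega_block"
```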
@@ -246,19 +244,17 @@ class MovingAverageGatedAttention(tf.keras.layers.Layer):
     return attn_weights

   def call(self, inputs: Any) -> Any:
-    """
-    MEGA encoder block call.
-    Args: inputs: a single tensor or a list of tensors. `input tensor`
+    """MEGA encoder block call.
+
+    Args:
+      inputs: a single tensor or a list of tensors. `input tensor`
         as the single sequence of embeddings. [`input tensor`,
         `attention mask`] to have the
         additional attention mask. [`query tensor`, `key value tensor`,
         `attention mask`] to have separate input streams for the query, and
         key/value to the multi-head attention.

     Returns:
-        An output tensor with the same dimensions as input/query tensor.
-    """
+      An output tensor with the same dimensions as input/query tensor.
+    """
     if isinstance(inputs, (list, tuple)):
       if len(inputs) == 2:
         (input_tensor, attention_mask) = inputs
@@ -310,7 +306,7 @@ class MovingAverageGatedAttention(tf.keras.layers.Layer):
     # L x B x E -> B x L x E
     v = tf.transpose(v, perm=(1, 0, 2))

-    attn_weights = self.softmax_attention(q, k)
+    attn_weights = self._softmax_attention(q, k)
     v = self.hidden_dropout(v)
     kernel = tf.squeeze(self.attention_dropout(attn_weights))
     # B x K x C x E -> B x L x E -> L x B x E
......