diff --git a/official/nlp/modeling/layers/kernel_attention_test.py b/official/nlp/modeling/layers/kernel_attention_test.py
index 064f504d8bda320edd03c15a6e7a7f7c1a117bda..fa86b71b96b0c48f93d9ec947dd64e9b04e8f059 100644
--- a/official/nlp/modeling/layers/kernel_attention_test.py
+++ b/official/nlp/modeling/layers/kernel_attention_test.py
@@ -68,10 +68,10 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
         value=value,
         attention_mask=masks,
         training=training)
+    dim = num_random_features if num_random_features > 0 else key_dim
     kv_cache = tf.zeros(
-        (batch_size, num_heads, key_dim,
-         num_random_features if num_random_features > 0 else key_dim))
-    k_sum_cache = tf.zeros((batch_size, num_heads, key_dim))
+        (batch_size, num_heads, dim, dim))
+    k_sum_cache = tf.zeros((batch_size, num_heads, dim))
     stream_output = []
     cache = {"kv": kv_cache, "k_sum": k_sum_cache}
     for i in range(num_chunks):
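
For context on the shapes being fixed: streaming kernel attention keeps a running `kv` matrix (accumulating feature-mapped keys against values) and a `k_sum` normalizer (accumulating feature-mapped keys), so both trailing cache axes track the feature dimension `dim`, not `key_dim`, whenever random features are in use; the old code mixed the two. Below is a minimal, illustrative sketch of that chunk-wise update under assumed shapes. The function name, the pre-mapped `q_prime`/`k_prime` inputs, and the chunk handling are assumptions for exposition, not the layer's actual API, and within-chunk causal masking is omitted for brevity:

```python
import tensorflow as tf

def streaming_kernel_attention_step(q_prime, k_prime, v, cache):
  """One chunk of streaming kernel attention with a running cache (sketch).

  q_prime, k_prime: feature-mapped queries/keys, [batch, heads, time, dim].
  v: values, [batch, heads, time, dim] (contraction axis == dim here).
  cache: {"kv": [batch, heads, dim, dim], "k_sum": [batch, heads, dim]}.
  """
  # Accumulate sum_t phi(k_t) v_t^T into the kv cache: [B, H, dim, dim].
  cache["kv"] += tf.einsum("bhtd,bhte->bhde", k_prime, v)
  # Accumulate sum_t phi(k_t) into the normalizer cache: [B, H, dim].
  cache["k_sum"] += tf.reduce_sum(k_prime, axis=2)
  # Numerator: phi(q) @ kv -> [B, H, T, dim]; denominator: phi(q) . k_sum.
  num = tf.einsum("bhtd,bhde->bhte", q_prime, cache["kv"])
  den = tf.einsum("bhtd,bhd->bht", q_prime, cache["k_sum"])
  return num / (den[..., None] + 1e-6), cache
```

Under this update rule, `kv` is `[batch, heads, dim, dim]` and `k_sum` is `[batch, heads, dim]`, matching the zero-initialized caches in the test above.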