提交 223776a4 编写于 作者: F Frederick Liu 提交者: A. Unique TensorFlower

[kernel] fix test head shape. This does not cause an error because we overwirte the cache.

PiperOrigin-RevId: 479993154
上级 0119344c
......@@ -71,7 +71,7 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
kv_cache = tf.zeros(
(batch_size, num_heads, key_dim,
num_random_features if num_random_features > 0 else key_dim))
k_sum_cache = tf.zeros((batch_size, 1, key_dim))
k_sum_cache = tf.zeros((batch_size, num_heads, key_dim))
stream_output = []
cache = {"kv": kv_cache, "k_sum": k_sum_cache}
for i in range(num_chunks):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册