提交 34f363ee 编写于 作者: F Frederick Liu 提交者: A. Unique TensorFlower

[kernel] Update test to make usage clear

PiperOrigin-RevId: 480718252
上级 b9e0e14f
......@@ -68,10 +68,10 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
value=value,
attention_mask=masks,
training=training)
dim = num_random_features if num_random_features > 0 else key_dim
kv_cache = tf.zeros(
(batch_size, num_heads, key_dim,
num_random_features if num_random_features > 0 else key_dim))
k_sum_cache = tf.zeros((batch_size, num_heads, key_dim))
(batch_size, num_heads, dim, dim))
k_sum_cache = tf.zeros((batch_size, num_heads, dim))
stream_output = []
cache = {"kv": kv_cache, "k_sum": k_sum_cache}
for i in range(num_chunks):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册