Commit 46bf5016, author: Frederick Liu, committer: A. Unique TensorFlower

[kernel] Improve readability by letting the user of the cache do the decay.

PiperOrigin-RevId: 477359324
Parent 4037c9e7
@@ -260,10 +260,11 @@ def causal_windowed_performer_attention(query_matrix,
     if window_decay is None or window_decay > 1.0 or window_decay < 0.0:
       raise ValueError("window_decay should be in (0.0, 1.0) and not None.")
-    kv = cache["kv"] + tf.einsum("BTHD,BTHO->BHOD", key_matrix, value_matrix)
-    cache["kv"] = kv * window_decay
-    k_sum = cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
-    cache["k_sum"] = k_sum * window_decay
+    kv = window_decay * cache["kv"] + tf.einsum(
+        "BTHD,BTHO->BHOD", key_matrix, value_matrix)
+    cache["kv"] = kv
+    k_sum = window_decay * cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
+    cache["k_sum"] = k_sum
     denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
     attention = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
                           1.0 / (denominator + _NUMERIC_STABLER))
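The change moves the decay from the cache write to the cache read. A minimal sketch of why the attention inputs are unaffected, assuming the cache starts at zero and using scalar stand-ins for the `kv` tensors; the function names `decay_on_write` and `decay_on_read` are hypothetical and not part of the commit:

```python
def decay_on_write(updates, decay):
    """Old scheme: decay is applied when the state is stored in the cache."""
    cache = 0.0
    used = []  # values of kv that would feed the attention einsum
    for update in updates:
        kv = cache + update
        cache = kv * decay
        used.append(kv)
    return used, cache


def decay_on_read(updates, decay):
    """New scheme: the cache holds the undecayed state; the reader decays it."""
    cache = 0.0
    used = []
    for update in updates:
        kv = decay * cache + update
        cache = kv
        used.append(kv)
    return used, cache


# Scalar stand-ins for the per-step einsum("BTHD,BTHO->BHOD", ...) terms.
updates = [1.0, 2.0, 4.0]
old_used, old_cache = decay_on_write(updates, decay=0.5)
new_used, new_cache = decay_on_read(updates, decay=0.5)

# Both schemes feed identical kv values to attention at every step.
assert old_used == new_used
# Only the stored value differs: the old cache is the new cache times decay.
assert old_cache == 0.5 * new_cache
```

Under these assumptions only the content of the stored cache differs, by one factor of `window_decay`, so a cache written by the old code would be decayed once more if read by the new code.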