Commit 2716fc32 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 507443799
Parent 26788724
@@ -266,8 +266,14 @@ def causal_windowed_performer_attention(query_matrix,
   k_sum = window_decay * cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
   cache["k_sum"] = k_sum
   denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
-  attention = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
-                        1.0 / (denominator + _NUMERIC_STABLER))
+  # The below is equivalent to but converts to TF Lite better than:
+  # tf.einsum("BTHD,BTH->BTHD",
+  #           query_matrix, 1.0 / (denominator + _NUMERIC_STABLER))
+  inverse_denominator = 1.0 / (denominator + _NUMERIC_STABLER)
+  # Add another dimension to align for the broadcast multiplication.
+  fused_query_denominator = query_matrix * tf.expand_dims(inverse_denominator,
+                                                          -1)
+  attention = tf.einsum("BTHD,BHOD->BTHO", fused_query_denominator, kv)
   return attention
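The rewrite only changes how the normalization is fused into the query: instead of passing the inverse denominator as a third operand to a single einsum, it is broadcast-multiplied into the query first, leaving a plain two-operand einsum that the TF Lite converter handles better. A minimal standalone sketch verifying that the two formulations agree on random tensors (the tensor shapes and the _NUMERIC_STABLER value below are assumptions for illustration, not taken from the commit):

import tensorflow as tf

_NUMERIC_STABLER = 1e-6  # assumed value, for illustration only

# B=batch, T=time, H=heads, D=feature dim, O=output dim (assumed sizes).
B, T, H, D, O = 2, 4, 3, 5, 6
query_matrix = tf.random.normal([B, T, H, D])
kv = tf.random.normal([B, H, O, D])
denominator = tf.random.uniform([B, T, H], minval=0.1, maxval=1.0)

# Original formulation: a three-operand einsum.
attention_old = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
                          1.0 / (denominator + _NUMERIC_STABLER))

# Rewritten formulation: fuse the normalizer into the query via a
# broadcast multiply, then use a two-operand einsum.
inverse_denominator = 1.0 / (denominator + _NUMERIC_STABLER)
fused_query_denominator = query_matrix * tf.expand_dims(inverse_denominator, -1)
attention_new = tf.einsum("BTHD,BHOD->BTHO", fused_query_denominator, kv)

# The two results match up to floating-point tolerance.
tf.debugging.assert_near(attention_old, attention_new, atol=1e-5)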