提交 d6916533 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

set `tf.expand_dims argument` `axis` to an integer instead of a list in the...

set `tf.expand_dims argument` `axis` to an integer instead of a list in the multi_head_attention model.

PiperOrigin-RevId: 380957913
上级 4b91b6c3
......@@ -424,10 +424,10 @@ class MultiHeadAttention(Layer):
if attention_mask is not None:
# The expand dim happens starting from the `num_heads` dimension,
# (<batch_dims>, num_heads, <query_attention_dims, key_attention_dims>)
mask_expansion_axes = [-len(self._attention_axes) * 2 - 1]
mask_expansion_axis = -len(self._attention_axes) * 2 - 1
for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
attention_mask = tf.expand_dims(
attention_mask, axis=mask_expansion_axes)
attention_mask, axis=mask_expansion_axis)
return self._softmax(attention_scores, attention_mask)
def _compute_attention(self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册