From d6916533c01a175012638c33014318f0137a9a21 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 22 Jun 2021 22:06:09 -0700
Subject: [PATCH] Set `tf.expand_dims` argument `axis` to an integer instead of
 a list in the multi_head_attention model.

PiperOrigin-RevId: 380957913
---
 keras/layers/multi_head_attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/keras/layers/multi_head_attention.py b/keras/layers/multi_head_attention.py
index 0721a4218..91b0909f9 100644
--- a/keras/layers/multi_head_attention.py
+++ b/keras/layers/multi_head_attention.py
@@ -424,10 +424,10 @@ class MultiHeadAttention(Layer):
     if attention_mask is not None:
       # The expand dim happens starting from the `num_heads` dimension,
       # (<batch_dims>, num_heads, <query_attention_dims, key_attention_dims>)
-      mask_expansion_axes = [-len(self._attention_axes) * 2 - 1]
+      mask_expansion_axis = -len(self._attention_axes) * 2 - 1
       for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
         attention_mask = tf.expand_dims(
-            attention_mask, axis=mask_expansion_axes)
+            attention_mask, axis=mask_expansion_axis)
     return self._softmax(attention_scores, attention_mask)
 
   def _compute_attention(self,
-- 
GitLab
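
Note (not part of the patch): below is a minimal standalone sketch of the patched mask-expansion logic. The shapes, the `attention_axes` tuple, and the variable names are illustrative assumptions standing in for the layer's internal `self._attention_axes` and runtime tensors; `tf.expand_dims` documents `axis` as a single integer, which is the reason the one-element list is replaced by a plain int.

import tensorflow as tf

# Illustrative shapes (assumptions, not taken from the patch).
batch, num_heads, query_len, key_len = 2, 4, 8, 8
attention_scores = tf.zeros([batch, num_heads, query_len, key_len])  # (B, N, T, S)
attention_mask = tf.ones([batch, query_len, key_len])                # (B, T, S)

# One attention axis, as for rank-3 queries; stands in for self._attention_axes.
attention_axes = (2,)
mask_expansion_axis = -len(attention_axes) * 2 - 1  # -3, the `num_heads` slot

# Insert size-1 dimensions until the mask rank matches the scores rank.
for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
  attention_mask = tf.expand_dims(attention_mask, axis=mask_expansion_axis)

print(attention_mask.shape)  # (2, 1, 8, 8): broadcasts against the scores

With the mask expanded to (B, 1, T, S), it broadcasts across the `num_heads` dimension of the attention scores when the masked softmax applies it.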