提交 f9e3eaa0 编写于 作者: H Hui Zhang

transpose in matmul

上级 3d7ca938
......@@ -188,8 +188,9 @@ class MultiHeadedAttention(nn.Layer):
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)
scores = paddle.matmul(q,
k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
# scores = paddle.matmul(q,
# k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
......@@ -309,11 +310,13 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
# matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True)
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
# matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True)
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册