Commit 70142d0c authored by Hui Zhang

view to reshape

Parent b8997058
@@ -177,9 +177,9 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
         x_padded = paddle.cat([zero_pad, x], dim=-1)
-        x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
-                                 x.shape[2])
-        x = x_padded[:, :, 1:].view_as(x)  # [B, H, T1, T1]
+        x_padded = x_padded.reshape(
+            [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
+        x = x_padded[:, :, 1:].reshape(x.shape)  # [B, H, T1, T1]
         if zero_triu:
             ones = paddle.ones((x.shape[2], x.shape[3]))
@@ -209,7 +209,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
         n_batch_pos = pos_emb.shape[0]
-        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = self.linear_pos(pos_emb).reshape(
+            [n_batch_pos, -1, self.h, self.d_k])
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         # (batch, head, time1, d_k)
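For context, a minimal self-contained sketch of the shift helper after this port (not the repository's exact code): paddle.Tensor has no .view()/.view_as(), so the commit switches to .reshape(), which takes the target shape as a list. The sketch assumes paddle.concat(..., axis=-1) in place of the torch-style paddle.cat(..., dim=-1) that appears unchanged in the diff context, and follows the [B, H, T1, T1] score layout noted in the surrounding code.

import paddle

def rel_shift(x: paddle.Tensor) -> paddle.Tensor:
    # x: attention scores of shape [B, H, T1, T1].
    # Prepend a zero column, fold the padded matrix across rows, then drop
    # the first row: this realigns the relative-position scores.
    zero_pad = paddle.zeros(
        (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
    x_padded = paddle.concat([zero_pad, x], axis=-1)
    # reshape takes the target shape as a list (vs. torch's .view(*dims))
    x_padded = x_padded.reshape(
        [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
    # .reshape(x.shape) stands in for torch's .view_as(x)
    return x_padded[:, :, 1:].reshape(x.shape)  # [B, H, T1, T1]

x = paddle.randn([2, 4, 5, 5])
print(rel_shift(x).shape)  # [2, 4, 5, 5]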