From 70142d0c66fe6675fe94d296c9d6c70768283687 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Wed, 7 Jul 2021 09:48:54 +0000
Subject: [PATCH] view to reshape

---
 deepspeech/modules/attention.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/deepspeech/modules/attention.py b/deepspeech/modules/attention.py
index 95ea8ebe..aba11a02 100644
--- a/deepspeech/modules/attention.py
+++ b/deepspeech/modules/attention.py
@@ -177,9 +177,9 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
         x_padded = paddle.cat([zero_pad, x], dim=-1)

-        x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
-                                 x.shape[2])
-        x = x_padded[:, :, 1:].view_as(x)  # [B, H, T1, T1]
+        x_padded = x_padded.reshape(
+            [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
+        x = x_padded[:, :, 1:].reshape(x.shape)  # [B, H, T1, T1]

         if zero_triu:
             ones = paddle.ones((x.shape[2], x.shape[3]))
@@ -209,7 +209,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)

         n_batch_pos = pos_emb.shape[0]
-        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = self.linear_pos(pos_emb).reshape(
+            [n_batch_pos, -1, self.h, self.d_k])
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)

         # (batch, head, time1, d_k)
--
GitLab
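
Note on the first hunk (not part of the patch): torch's Tensor.view takes the
target sizes as positional arguments and view_as(other) borrows another
tensor's shape, while paddle.Tensor.reshape takes the shape as a single list,
so view_as(x) becomes reshape(x.shape). Below is a minimal standalone sketch
of the Transformer-XL relative-shift trick this hunk rewrites; the sizes B, H,
T are made up for illustration, and the real code reads them from x.shape.

    import paddle

    # Illustrative sizes only; x plays the role of the [B, H, T1, T1]
    # relative-position score tensor in rel_shift.
    B, H, T = 2, 4, 6
    x = paddle.randn([B, H, T, T])

    # Prepend a zero column, then reinterpret the flat layout so that row i
    # of the last axis ends up shifted by i positions (the relative shift).
    zero_pad = paddle.zeros([B, H, T, 1], dtype=x.dtype)
    x_padded = paddle.concat([zero_pad, x], axis=-1)   # [B, H, T, T+1]
    x_padded = x_padded.reshape([B, H, T + 1, T])      # swap last two dims
    shifted = x_padded[:, :, 1:].reshape(x.shape)      # back to [B, H, T, T]

    print(shifted.shape)  # [2, 4, 6, 6]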
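The second hunk is the same migration applied to the positional projection: a
torch-style .view(n_batch_pos, -1, self.h, self.d_k) becomes
.reshape([n_batch_pos, -1, self.h, self.d_k]) to split the projected embedding
into attention heads. A sketch of that head split follows; the names and sizes
here are illustrative stand-ins, not the module's real attributes.

    import paddle
    import paddle.nn as nn

    h, d_k = 4, 16                           # illustrative heads / head width
    linear_pos = nn.Linear(h * d_k, h * d_k, bias_attr=False)

    pos_emb = paddle.randn([1, 10, h * d_k])             # (batch, time1, d_model)
    p = linear_pos(pos_emb).reshape([pos_emb.shape[0], -1, h, d_k])
    p = p.transpose([0, 2, 1, 3])                        # (batch, head, time1, d_k)

    print(p.shape)  # [1, 4, 10, 16]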