diff --git a/deepspeech/modules/attention.py b/deepspeech/modules/attention.py index 95ea8ebec656ad9f18bb282cc8ed50bba26cdfe4..aba11a02961a8d5597df073bd65aeff6c551edd3 100644 --- a/deepspeech/modules/attention.py +++ b/deepspeech/modules/attention.py @@ -177,9 +177,9 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype) x_padded = paddle.cat([zero_pad, x], dim=-1) - x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, - x.shape[2]) - x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] + x_padded = x_padded.reshape( + [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]]) + x = x_padded[:, :, 1:].reshape(x.shape) # [B, H, T1, T1] if zero_triu: ones = paddle.ones((x.shape[2], x.shape[3])) @@ -209,7 +209,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) n_batch_pos = pos_emb.shape[0] - p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = self.linear_pos(pos_emb).reshape( + [n_batch_pos, -1, self.h, self.d_k]) p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) # (batch, head, time1, d_k)