From 5e32139e3a5cc391bc81cae83942e3ed7c901cc2 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Mon, 5 Jul 2021 10:47:18 +0000
Subject: [PATCH] view to reshape

---
 deepspeech/modules/attention.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/deepspeech/modules/attention.py b/deepspeech/modules/attention.py
index afc70214..4ff57a94 100644
--- a/deepspeech/modules/attention.py
+++ b/deepspeech/modules/attention.py
@@ -71,9 +71,9 @@ class MultiHeadedAttention(nn.Layer):
                 (#batch, n_head, time2, d_k).
         """
         n_batch = query.shape[0]
-        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = self.linear_q(query).reshape([n_batch, -1, self.h, self.d_k])
+        k = self.linear_k(key).reshape([n_batch, -1, self.h, self.d_k])
+        v = self.linear_v(value).reshape([n_batch, -1, self.h, self.d_k])
         q = q.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         k = k.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
         v = v.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
@@ -109,8 +109,8 @@ class MultiHeadedAttention(nn.Layer):
         p_attn = self.dropout(attn)

         x = paddle.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose([0, 2, 1, 3]).contiguous().view(
-            n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
+        x = x.transpose([0, 2, 1, 3]).contiguous().reshape([
+            n_batch, -1, self.h * self.d_k])  # (batch, time1, d_model)
         return self.linear_out(x)  # (batch, time1, d_model)
--
GitLab
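
A minimal sketch of the API difference behind this patch, using hypothetical toy shapes that are not taken from the code above: paddle.Tensor has no PyTorch-style view(*dims) method; the equivalent is reshape(shape), with the target dimensions passed as a list. The head-split and head-merge pattern used in MultiHeadedAttention then looks roughly like the following (the .contiguous() call kept by the patch is omitted here, since the reshape alone illustrates the point):

    import paddle

    # Hypothetical toy sizes, for illustration only.
    n_batch, time1, d_model = 2, 5, 8
    h = 4                       # number of attention heads
    d_k = d_model // h          # per-head feature size

    x = paddle.randn([n_batch, time1, d_model])

    # Split into heads: reshape takes a list of dims
    # (a PyTorch-style x.view(n_batch, -1, h, d_k) call would fail on a Paddle tensor).
    q = x.reshape([n_batch, -1, h, d_k])   # (batch, time1, head, d_k)
    q = q.transpose([0, 2, 1, 3])          # (batch, head, time1, d_k)

    # Merge the heads back, mirroring the second hunk of the patch.
    out = q.transpose([0, 2, 1, 3]).reshape([n_batch, -1, h * d_k])  # (batch, time1, d_model)
    print(out.shape)  # [2, 5, 8]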