提交 46088c0a 编写于 作者: H Hui Zhang

eliminate attn transpose

上级 f9e3eaa0
...@@ -271,7 +271,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -271,7 +271,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
and `head * d_k == size` and `head * d_k == size`
""" """
q, k, v = self.forward_qkv(query, key, value) q, k, v = self.forward_qkv(query, key, value)
q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) # q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# when export onnx model, for 1st chunk, we feed # when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
...@@ -302,9 +302,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -302,9 +302,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
# (batch, head, time1, d_k) # (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
# (batch, head, time1, d_k) # (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)
# compute attention score # compute attention score
# first compute matrix a and matrix c # first compute matrix a and matrix c
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册