Commit 925abcca authored by: H Hui Zhang

format

Parent 2a75405e
@@ -19,8 +19,8 @@ from typing import Tuple
 import paddle
 from paddle import nn
-from paddle.nn import initializer as I
 from paddle.nn import functional as F
+from paddle.nn import initializer as I
 from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.utils.log import Log
@@ -56,12 +56,12 @@ class MultiHeadedAttention(nn.Layer):
         self.linear_out = Linear(n_feat, n_feat)
         self.dropout = nn.Dropout(p=dropout_rate)

     def _build_once(self, *args, **kwargs):
         super()._build_once(*args, **kwargs)
         # if self.self_att:
         #     self.linear_kv = Linear(self.n_feat, self.n_feat*2)
-        self.weight = paddle.concat([self.linear_k.weight, self.linear_v.weight], axis=-1)
+        self.weight = paddle.concat(
+            [self.linear_k.weight, self.linear_v.weight], axis=-1)
         self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
         self._built = True
@@ -84,12 +84,14 @@ class MultiHeadedAttention(nn.Layer):
                 (#batch, n_head, time2, d_k).
         """
         n_batch = query.shape[0]
         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
         # k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
         # v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
-        k, v = F.linear(key, self.weight, self.bias).view(n_batch, -1, 2 * self.h, self.d_k).split(2, axis=2)
+        k, v = F.linear(key, self.weight, self.bias).view(
+            n_batch, -1, 2 * self.h, self.d_k).split(
+                2, axis=2)
         q = q.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         k = k.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
         v = v.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
@@ -203,7 +205,7 @@ class MultiHeadedAttention(nn.Layer):
         new_cache = paddle.concat((k, v), axis=-1)

         # scores = paddle.matmul(q,
         #                        k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
         scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
         return self.forward_attention(v, scores, mask), new_cache
......
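The hunks above fold the separate key/value projections into one matrix multiply: _build_once concatenates the weights and biases of linear_k and linear_v once, and the query/key/value step then runs a single F.linear followed by a split to recover k and v. The scores line in the last hunk is the usual scaled dot-product, matmul(q, k, transpose_y=True) / sqrt(d_k), which avoids the explicit transpose of k. A minimal standalone sketch of the fused projection (toy sizes, reshape used in place of the module's view helper; not the module's actual code):

import paddle
import paddle.nn.functional as F

n_feat, h, d_k = 8, 2, 4                  # toy sizes, n_feat == h * d_k
linear_k = paddle.nn.Linear(n_feat, n_feat)
linear_v = paddle.nn.Linear(n_feat, n_feat)

# Fuse the two projections: concatenate weights along the output axis.
weight = paddle.concat([linear_k.weight, linear_v.weight], axis=-1)  # (n_feat, 2*n_feat)
bias = paddle.concat([linear_k.bias, linear_v.bias])                 # (2*n_feat,)

key = paddle.randn([1, 5, n_feat])        # (batch, time2, n_feat)
n_batch = key.shape[0]

# One matmul, then reshape to (batch, time2, 2*h, d_k) and split into k and v.
k, v = F.linear(key, weight, bias).reshape(
    [n_batch, -1, 2 * h, d_k]).split(2, axis=2)

# Equivalent to running the two projections separately.
k_ref = linear_k(key).reshape([n_batch, -1, h, d_k])
v_ref = linear_v(key).reshape([n_batch, -1, h, d_k])
print(paddle.allclose(k, k_ref), paddle.allclose(v, v_ref))  # True, True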
@@ -221,7 +221,7 @@ class BaseEncoder(nn.Layer):
         xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
         # after embed, xs=(B=1, chunk_size, hidden-dim)
         elayers, _, cache_t1, _ = att_cache.shape
         chunk_size = xs.shape[1]
         attention_key_size = cache_t1 + chunk_size
......
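For context on the BaseEncoder hunk: during chunk-by-chunk decoding the per-layer attention cache holds previously seen keys/values (concatenated on the last axis, as in new_cache above), so the span the current chunk can attend to is the cached length plus the chunk length. A shapes-only sketch with made-up sizes:

import paddle

# Hypothetical streaming state: 12 layers, 4 heads, 16 cached frames,
# last axis holds k and v concatenated (2 * head_dim).
att_cache = paddle.zeros([12, 4, 16, 2 * 64])
xs = paddle.randn([1, 8, 256])              # current chunk after embedding

elayers, _, cache_t1, _ = att_cache.shape
chunk_size = xs.shape[1]
attention_key_size = cache_t1 + chunk_size  # keys visible to this chunk
print(elayers, cache_t1, chunk_size, attention_key_size)  # 12 16 8 24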
@@ -110,7 +110,7 @@ def subsequent_mask(size: int) -> paddle.Tensor:
    """
    ret = paddle.ones([size, size], dtype=paddle.bool)
    return paddle.tril(ret)


def subsequent_chunk_mask(
        size: int,
......
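The mask hunk keeps subsequent_mask as a boolean lower-triangular matrix built with paddle.tril: row i is True only for columns 0..i, which is the causal pattern the attention scores are masked with. A quick usage sketch:

import paddle

size = 4
ret = paddle.ones([size, size], dtype=paddle.bool)
mask = paddle.tril(ret)           # position i may attend to positions 0..i
print(mask.astype('int32'))
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]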