PaddlePaddle / DeepSpeech
Commit fb40602d — refactor attention cache
Authored on Jul 08, 2022 by Hui Zhang
Parent: e1534955

Showing 8 changed files with 267 additions and 176 deletions (+267 −176)
paddlespeech/s2t/models/u2/u2.py                    +30  −17
paddlespeech/s2t/modules/attention.py               +109 −21
paddlespeech/s2t/modules/conformer_convolution.py   +10  −8
paddlespeech/s2t/modules/decoder_layer.py           +4   −4
paddlespeech/s2t/modules/embedding.py               +1   −1
paddlespeech/s2t/modules/encoder.py                 +57  −58
paddlespeech/s2t/modules/encoder_layer.py           +54  −65
setup.py                                            +2   −2
paddlespeech/s2t/models/u2/u2.py

@@ -605,29 +605,42 @@ class U2BaseModel(ASRInterface, nn.Layer):
             xs: paddle.Tensor,
             offset: int,
             required_cache_size: int,
-            subsampling_cache: Optional[paddle.Tensor]=None,
-            elayers_output_cache: Optional[List[paddle.Tensor]]=None,
-            conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor],
-               List[paddle.Tensor]]:
+            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Export interface for c++ call, give input chunk xs, and return
             output from time 0 to current chunk.
         Args:
-            xs (paddle.Tensor): chunk input
-            subsampling_cache (Optional[paddle.Tensor]): subsampling cache
-            elayers_output_cache (Optional[List[paddle.Tensor]]):
-                transformer/conformer encoder layers output cache
-            conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
-                cnn cache
+            xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim),
+                where `time == (chunk_size - 1) * subsample_rate + \
+                        subsample.right_context + 1`
             offset (int): current offset in encoder output time stamp
             required_cache_size (int): cache size required for next chunk
                 computation
                 >=0: actual cache size
                 <0: means all history cache is required
+            att_cache (paddle.Tensor): cache tensor for KEY & VALUE in
+                transformer/conformer attention, with shape
+                (elayers, head, cache_t1, d_k * 2), where
+                `head * d_k == hidden-dim` and
+                `cache_t1 == chunk_size * num_decoding_left_chunks`.
+                `d_k * 2` for att key & value.
+            cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
+                (elayers, b=1, hidden-dim, cache_t2), where
+                `cache_t2 == cnn.lorder - 1`
         Returns:
-            paddle.Tensor: output, it ranges from time 0 to current chunk.
-            paddle.Tensor: subsampling cache
-            List[paddle.Tensor]: attention cache
-            List[paddle.Tensor]: conformer cnn cache
+            paddle.Tensor: output of current input xs,
+                with shape (b=1, chunk_size, hidden-dim).
+            paddle.Tensor: new attention cache required for next chunk, with
+                dynamic shape (elayers, head, T(?), d_k * 2)
+                depending on required_cache_size.
+            paddle.Tensor: new conformer cnn cache required for next chunk, with
+                same shape as the original cnn_cache.
         """
         return self.encoder.forward_chunk(xs, offset, required_cache_size,
-                                          subsampling_cache,
-                                          elayers_output_cache,
-                                          conformer_cnn_cache)
+                                          att_cache, cnn_cache)

     # @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
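With the three per-layer cache lists collapsed into two stacked tensors, a caller only threads `att_cache` and `cnn_cache` through the decoding loop. A minimal driver sketch under assumed names (`model` for a loaded U2 model, `chunks` for pre-split feature chunks, and the exported method name, none of which appear in this hunk):

import paddle

# Caches start as zero-shaped placeholders; each call returns the caches
# to feed into the next chunk (hypothetical `model` and `chunks`).
att_cache = paddle.zeros([0, 0, 0, 0])
cnn_cache = paddle.zeros([0, 0, 0, 0])
offset, outputs = 0, []
for chunk in chunks:  # each chunk: (b=1, time, mel-dim)
    y, att_cache, cnn_cache = model.forward_encoder_chunk(
        chunk, offset, 16, att_cache, cnn_cache)  # required_cache_size=16
    offset += y.shape[1]
    outputs.append(y)
ys = paddle.concat(outputs, axis=1)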
paddlespeech/s2t/modules/attention.py

@@ -84,9 +84,10 @@ class MultiHeadedAttention(nn.Layer):
         return q, k, v

     def forward_attention(
             self,
-            value: paddle.Tensor,
-            scores: paddle.Tensor,
-            mask: Optional[paddle.Tensor]) -> paddle.Tensor:
+            value: paddle.Tensor,
+            scores: paddle.Tensor,
+            mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+    ) -> paddle.Tensor:
         """Compute attention context vector.
         Args:
             value (paddle.Tensor): Transformed value, size
@@ -94,14 +95,23 @@ class MultiHeadedAttention(nn.Layer):
             scores (paddle.Tensor): Attention score, size
                 (#batch, n_head, time1, time2).
             mask (paddle.Tensor): Mask, size (#batch, 1, time2) or
-                (#batch, time1, time2).
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
         Returns:
-            paddle.Tensor: Transformed value weighted
-                by the attention score, (#batch, time1, d_model).
+            paddle.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
         """
         n_batch = value.shape[0]
-        if mask is not None:
+
+        # When `if mask.size(2) > 0` be True:
+        #   1. training.
+        #   2. onnx(16/4, chunk_size/history_size), feed real cache and real
+        #      mask for the 1st chunk.
+        # When will `if mask.size(2) > 0` be False?
+        #   1. onnx(16/-1, -1/-1, 16/0)
+        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
+        if paddle.shape(mask)[2] > 0:  # time2 > 0
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            # for last chunk, time2 might be larger than scores.size(-1)
+            mask = mask[:, :, :, :paddle.shape(scores)[-1]]
             scores = scores.masked_fill(mask, -float('inf'))
             attn = paddle.softmax(scores, axis=-1).masked_fill(mask,
@@ -121,21 +131,67 @@ class MultiHeadedAttention(nn.Layer):
             query: paddle.Tensor,
             key: paddle.Tensor,
             value: paddle.Tensor,
-            mask: Optional[paddle.Tensor]) -> paddle.Tensor:
+            mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+            pos_emb: paddle.Tensor=paddle.empty([0]),
+            cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute scaled dot product attention.
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+        Args:
+            query (paddle.Tensor): Query tensor (#batch, time1, size).
+            key (paddle.Tensor): Key tensor (#batch, time2, size).
+            value (paddle.Tensor): Value tensor (#batch, time2, size).
+            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
                 (#batch, time1, time2).
+                1. When applying cross attention between decoder and encoder,
+                   the batch padding mask for input is in (#batch, 1, T) shape.
+                2. When applying self attention of encoder,
+                   the mask is in (#batch, T, T) shape.
+                3. When applying self attention of decoder,
+                   the mask is in (#batch, L, L) shape.
+                4. If different positions in the decoder see different blocks
+                   of the encoder, such as Mocha, the passed-in mask could be
+                   in (#batch, L, T) shape. But there is no such case in current
+                   Wenet.
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
         Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
         """
         q, k, v = self.forward_qkv(query, key, value)

+        # when export onnx model, for 1st chunk, we feed
+        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #   In all modes, `if cache.size(0) > 0` will always be `True`
+        #   and we will always do splitting and
+        #   concatenation (this will simplify onnx export). Note that
+        #   it's OK to concat & split zero-shaped tensors (see code below).
+        # when export jit model, for 1st chunk, we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if paddle.shape(cache)[0] > 0:
+            # last dim `d_k * 2` for (key, val)
+            key_cache, value_cache = paddle.split(
+                cache, paddle.shape(cache)[-1] // 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)
+
         scores = paddle.matmul(q,
                                k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask)
+        return self.forward_attention(v, scores, mask), new_cache


 class RelPositionMultiHeadedAttention(MultiHeadedAttention):
@@ -192,23 +248,55 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             query: paddle.Tensor,
             key: paddle.Tensor,
             value: paddle.Tensor,
-            pos_emb: paddle.Tensor,
-            mask: Optional[paddle.Tensor]):
+            mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+            pos_emb: paddle.Tensor=paddle.empty([0]),
+            cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:
             query (paddle.Tensor): Query tensor (#batch, time1, size).
             key (paddle.Tensor): Key tensor (#batch, time2, size).
             value (paddle.Tensor): Value tensor (#batch, time2, size).
-            pos_emb (paddle.Tensor): Positional embedding tensor
-                (#batch, time1, size).
             mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
         Returns:
             paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
         """
         q, k, v = self.forward_qkv(query, key, value)
         q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)

+        # when export onnx model, for 1st chunk, we feed
+        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #   In all modes, `if cache.size(0) > 0` will always be `True`
+        #   and we will always do splitting and
+        #   concatenation (this will simplify onnx export). Note that
+        #   it's OK to concat & split zero-shaped tensors (see code below).
+        # when export jit model, for 1st chunk, we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if paddle.shape(cache)[0] > 0:
+            key_cache, value_cache = paddle.split(
+                cache, paddle.shape(cache)[-1] // 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)
+
         n_batch_pos = pos_emb.shape[0]
         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
@@ -234,4 +322,4 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         scores = (matrix_ac + matrix_bd) / math.sqrt(
             self.d_k)  # (batch, head, time1, time2)

-        return self.forward_attention(v, scores, mask)
+        return self.forward_attention(v, scores, mask), new_cache
\ No newline at end of file
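The comment block above justifies always taking the concat path by noting that concatenation with zero-sized tensors is a no-op; its doctest is written against torch. The same check in paddle (an illustrative snippet, not part of the commit; assumes a paddle version that accepts zero-sized tensors in concat):

import paddle

a = paddle.ones([1, 2, 0, 4])      # empty cache: cache_t == 0
b = paddle.ones([1, 2, 3, 4])      # current chunk's keys/values
c = paddle.concat([a, b], axis=2)  # concat along the time axis
assert c.shape == b.shape          # [1, 2, 3, 4]: concatenating the
                                   # empty cache changed nothing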
paddlespeech/s2t/modules/conformer_convolution.py

@@ -108,15 +108,17 @@ class ConvolutionModule(nn.Layer):
     def forward(
             self,
             x: paddle.Tensor,
-            mask_pad: Optional[paddle.Tensor]=None,
-            cache: Optional[paddle.Tensor]=None
+            mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+            cache: paddle.Tensor=paddle.zeros([0, 0, 0]),
     ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         """Compute convolution module.
         Args:
             x (paddle.Tensor): Input tensor (#batch, time, channels).
-            mask_pad (paddle.Tensor): used for batch padding, (#batch, channels, time).
+            mask_pad (paddle.Tensor): used for batch padding (#batch, 1, time),
+                (0, 0, 0) means fake mask.
             cache (paddle.Tensor): left context cache, it is only
-                used in causal convolution. (#batch, channels, time')
+                used in causal convolution (#batch, channels, cache_t),
+                (0, 0, 0) means fake cache.
         Returns:
             paddle.Tensor: Output tensor (#batch, time, channels).
             paddle.Tensor: Output cache tensor (#batch, channels, time')
@@ -125,11 +127,11 @@ class ConvolutionModule(nn.Layer):
         x = x.transpose([0, 2, 1])  # [B, C, T]

         # mask batch padding
-        if mask_pad is not None:
+        if paddle.shape(mask_pad)[2] > 0:  # time > 0
             x = x.masked_fill(mask_pad, 0.0)

         if self.lorder > 0:
-            if cache is None:
+            if paddle.shape(cache)[2] == 0:  # cache_t == 0
                 x = nn.functional.pad(
                     x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
             else:
@@ -143,7 +145,7 @@ class ConvolutionModule(nn.Layer):
             # It's better we just return None if no cache is required,
             # However, for JIT export, here we just fake one tensor instead of
             # None.
-            new_cache = paddle.zeros([1], dtype=x.dtype)
+            new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)

         # GLU mechanism
         x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
@@ -159,7 +161,7 @@ class ConvolutionModule(nn.Layer):
         x = self.pointwise_conv2(x)

         # mask batch padding
-        if mask_pad is not None:
+        if paddle.shape(mask_pad)[2] > 0:  # time > 0
             x = x.masked_fill(mask_pad, 0.0)

         x = x.transpose([0, 2, 1])  # [B, T, C]
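The `lorder` left-context logic is worth seeing concretely: the first chunk has an empty cache, so `lorder` zeros are padded on the left; afterwards the tail of each chunk becomes the next chunk's cache. A sketch with assumed sizes (kernel_size=15 so lorder=14, 256 channels, 16-frame chunks; these literals are illustrative, not from the diff):

import paddle
import paddle.nn as nn

lorder = 14                                # kernel_size - 1 for a causal conv
x = paddle.randn([1, 256, 16])             # (B, C, T): first chunk, empty cache
x = nn.functional.pad(x, [lorder, 0], 'constant', 0.0, data_format='NCL')
new_cache = x[:, :, -lorder:]              # (B, C, cache_t): left context
                                           # carried into the next chunk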
paddlespeech/s2t/modules/decoder_layer.py

@@ -121,11 +121,11 @@ class DecoderLayer(nn.Layer):
         if self.concat_after:
             tgt_concat = paddle.cat(
-                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]),
+                dim=-1)
             x = residual + self.concat_linear1(tgt_concat)
         else:
             x = residual + self.dropout(
-                self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+                self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
         if not self.normalize_before:
             x = self.norm1(x)
@@ -134,11 +134,11 @@ class DecoderLayer(nn.Layer):
             x = self.norm2(x)
         if self.concat_after:
             x_concat = paddle.cat(
-                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
+                (x, self.src_attn(x, memory, memory, memory_mask)[0]),
+                dim=-1)
             x = residual + self.concat_linear2(x_concat)
         else:
             x = residual + self.dropout(
-                self.src_attn(x, memory, memory, memory_mask))
+                self.src_attn(x, memory, memory, memory_mask)[0])
         if not self.normalize_before:
             x = self.norm2(x)
paddlespeech/s2t/modules/embedding.py

@@ -131,7 +131,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
             offset (int): start offset
             size (int): required size of position encoding
         Returns:
-            paddle.Tensor: Corresponding position encoding
+            paddle.Tensor: Corresponding position encoding, #[1, T, D].
         """
         assert offset + size < self.max_len
         return self.dropout(self.pe[:, offset:offset + size])
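The added `#[1, T, D]` note matters for streaming: `forward_chunk` in encoder.py (below) calls this method to fetch the encoding covering both the cached frames and the current chunk. An illustrative call with assumed numbers (offset 48, cache_t1 32, chunk_size 16; `pos_enc` is a hypothetical PositionalEncoding instance):

# Rewind to the first cached frame, then cover cache + chunk positions;
# the result has shape (1, 48, D).
pos_emb = pos_enc.position_encoding(offset=48 - 32, size=32 + 16)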
paddlespeech/s2t/modules/encoder.py

@@ -177,7 +177,7 @@ class BaseEncoder(nn.Layer):
             decoding_chunk_size, self.static_chunk_size,
             num_decoding_left_chunks)
         for layer in self.encoders:
-            xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         if self.normalize_before:
             xs = self.after_norm(xs)
         # Here we assume the mask is not changed in encoder layers, so just
@@ -190,30 +190,31 @@ class BaseEncoder(nn.Layer):
             xs: paddle.Tensor,
             offset: int,
             required_cache_size: int,
-            subsampling_cache: Optional[paddle.Tensor]=None,
-            elayers_output_cache: Optional[List[paddle.Tensor]]=None,
-            conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor],
-               List[paddle.Tensor]]:
+            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Forward just one chunk
         Args:
-            xs (paddle.Tensor): chunk input, [B=1, T, D]
+            xs (paddle.Tensor): chunk audio feat input, [B=1, T, D], where
+                `T==(chunk_size-1)*subsampling_rate + subsample.right_context + 1`
             offset (int): current offset in encoder output time stamp
             required_cache_size (int): cache size required for next chunk
                 computation
                 >=0: actual cache size
                 <0: means all history cache is required
-            subsampling_cache (Optional[paddle.Tensor]): subsampling cache
-            elayers_output_cache (Optional[List[paddle.Tensor]]):
-                transformer/conformer encoder layers output cache
-            conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
-                cnn cache
+            att_cache (paddle.Tensor): cache tensor for key & val in
+                transformer/conformer attention. Shape is
+                (elayers, head, cache_t1, d_k * 2), where `head * d_k == hidden-dim`
+                and `cache_t1 == chunk_size * num_decoding_left_chunks`.
+            cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
+                (elayers, B=1, hidden-dim, cache_t2), where `cache_t2 == cnn.lorder - 1`
         Returns:
-            paddle.Tensor: output of current input xs
-            paddle.Tensor: subsampling cache required for next chunk computation
-            List[paddle.Tensor]: encoder layers output cache required for next
-                chunk computation
-            List[paddle.Tensor]: conformer cnn cache
+            paddle.Tensor: output of current input xs, (B=1, chunk_size, hidden-dim)
+            paddle.Tensor: new attention cache required for next chunk, dynamic shape
+                (elayers, head, T, d_k*2) depending on required_cache_size
+            paddle.Tensor: new conformer cnn cache required for next chunk, with
+                same shape as the original cnn_cache
         """
         assert xs.shape[0] == 1  # batch size must be one
         # tmp_masks is just for interface compatibility
@@ -225,50 +226,49 @@ class BaseEncoder(nn.Layer):
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)

-        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
-        #xs=(B, T, D), pos_emb=(B=1, T, D)
+        # before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
+        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
+        # after embed, xs=(B=1, chunk_size, hidden-dim)

-        if subsampling_cache is not None:
-            cache_size = subsampling_cache.shape[1]  # T
-            xs = paddle.cat((subsampling_cache, xs), dim=1)
-        else:
-            cache_size = 0
+        elayers, cache_t1 = paddle.shape(att_cache)[0], paddle.shape(att_cache)[2]
+        chunk_size = paddle.shape(xs)[1]
+        attention_key_size = cache_t1 + chunk_size

         # only used when using `RelPositionMultiHeadedAttention`
         pos_emb = self.embed.position_encoding(
-            offset=offset - cache_size, size=xs.shape[1])
+            offset=offset - cache_t1, size=attention_key_size)

         if required_cache_size < 0:
             next_cache_start = 0
         elif required_cache_size == 0:
-            next_cache_start = xs.shape[1]
+            next_cache_start = attention_key_size
         else:
-            next_cache_start = xs.shape[1] - required_cache_size
-        r_subsampling_cache = xs[:, next_cache_start:, :]
-
-        # Real mask for transformer/conformer layers
-        masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
-        masks = masks.unsqueeze(1)  # [B=1, L'=1, T]
-        r_elayers_output_cache = []
-        r_conformer_cnn_cache = []
+            next_cache_start = max(attention_key_size - required_cache_size, 0)
+
+        r_att_cache = []
+        r_cnn_cache = []
         for i, layer in enumerate(self.encoders):
-            attn_cache = None if elayers_output_cache is None else elayers_output_cache[i]
-            cnn_cache = None if conformer_cnn_cache is None else conformer_cnn_cache[i]
-            xs, _, new_cnn_cache = layer(
-                xs, masks, pos_emb,
-                output_cache=attn_cache,
-                cnn_cache=cnn_cache)
-            r_elayers_output_cache.append(xs[:, next_cache_start:, :])
-            r_conformer_cnn_cache.append(new_cnn_cache)
+            # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
+            # cnn_cache[i]    = (B=1, hidden-dim, cache_t2)
+            xs, _, new_att_cache, new_cnn_cache = layer(
+                xs, att_mask, pos_emb,
+                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
+                cnn_cache=cnn_cache[i]
+                if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
+            # new_att_cache = (1, head, attention_key_size, d_k*2)
+            # new_cnn_cache = (B=1, hidden-dim, cache_t2)
+            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
+            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))  # add elayer dim
+
         if self.normalize_before:
             xs = self.after_norm(xs)

-        return (xs[:, cache_size:, :], r_subsampling_cache,
-                r_elayers_output_cache, r_conformer_cnn_cache)
+        # r_att_cache (elayers, head, T, d_k*2)
+        # r_cnn_cache (elayers, B=1, hidden-dim, cache_t2)
+        r_att_cache = paddle.concat(r_att_cache, axis=0)
+        r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
+        return xs, r_att_cache, r_cnn_cache

     def forward_chunk_by_chunk(
             self,
@@ -313,25 +313,24 @@ class BaseEncoder(nn.Layer):
         num_frames = xs.shape[1]
         required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-        subsampling_cache: Optional[paddle.Tensor] = None
-        elayers_output_cache: Optional[List[paddle.Tensor]] = None
-        conformer_cnn_cache: Optional[List[paddle.Tensor]] = None
+        att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+        cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+
         outputs = []
         offset = 0
         # Feed forward overlap input step by step
         for cur in range(0, num_frames - context + 1, stride):
             end = min(cur + decoding_window, num_frames)
             chunk_xs = xs[:, cur:end, :]
-            (y, subsampling_cache, elayers_output_cache,
-             conformer_cnn_cache) = self.forward_chunk(
-                 chunk_xs, offset, required_cache_size, subsampling_cache,
-                 elayers_output_cache, conformer_cnn_cache)
+            (y, att_cache, cnn_cache) = self.forward_chunk(
+                chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
             outputs.append(y)
             offset += y.shape[1]
         ys = paddle.cat(outputs, 1)
-        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
-        masks = masks.unsqueeze(1)
+        # fake mask, just for jit script and compatibility with `forward` api
+        masks = paddle.ones([1, 1, ys.shape[1]], dtype=paddle.bool)
         return ys, masks
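The `next_cache_start` arithmetic is easiest to check with concrete numbers. A worked example with assumed sizes (chunk_size=16 and four left chunks, so required_cache_size=64; the literals are illustrative):

cache_t1 = 64                               # frames already cached
chunk_size = 16
attention_key_size = cache_t1 + chunk_size  # 80 key/value frames this step
required_cache_size = 64
next_cache_start = max(attention_key_size - required_cache_size, 0)  # 16
# new_att_cache[:, :, 16:, :] keeps the newest 64 of the 80 frames, so
# cache_t1 stays at 64 for the next chunk; required_cache_size < 0 keeps
# everything (next_cache_start = 0).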
paddlespeech/s2t/modules/encoder_layer.py

@@ -75,49 +75,45 @@ class TransformerEncoderLayer(nn.Layer):
             self,
             x: paddle.Tensor,
             mask: paddle.Tensor,
-            pos_emb: Optional[paddle.Tensor]=None,
-            mask_pad: Optional[paddle.Tensor]=None,
-            output_cache: Optional[paddle.Tensor]=None,
-            cnn_cache: Optional[paddle.Tensor]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+            pos_emb: paddle.Tensor,
+            mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.
         Args:
-            x (paddle.Tensor): Input tensor (#batch, time, size).
-            mask (paddle.Tensor): Mask tensor for the input (#batch, time).
+            x (paddle.Tensor): (#batch, time, size)
+            mask (paddle.Tensor): Mask tensor for the input (#batch, time, time),
+                (0, 0, 0) means fake mask.
             pos_emb (paddle.Tensor): just for interface compatibility
                 to ConformerEncoderLayer
-            mask_pad (paddle.Tensor): not used here, it's for interface
-                compatibility to ConformerEncoderLayer
-            output_cache (paddle.Tensor): Cache tensor of the output
-                (#batch, time2, size), time2 < time in x.
-            cnn_cache (paddle.Tensor): not used here, it's for interface
-                compatibility to ConformerEncoderLayer
+            mask_pad (paddle.Tensor): not used in transformer layer,
+                just for unified api with conformer.
+            att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
+                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+            cnn_cache (paddle.Tensor): Convolution cache in conformer layer
+                (#batch=1, size, cache_t2), not used here, it's for interface
+                compatibility to ConformerEncoderLayer.
         Returns:
             paddle.Tensor: Output tensor (#batch, time, size).
-            paddle.Tensor: Mask tensor (#batch, time).
-            paddle.Tensor: Fake cnn cache tensor for api compatibility with Conformer (#batch, channels, time').
+            paddle.Tensor: Mask tensor (#batch, time, time).
+            paddle.Tensor: att_cache tensor,
+                (#batch=1, head, cache_t1 + time, d_k * 2).
+            paddle.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
         """
         residual = x
         if self.normalize_before:
             x = self.norm1(x)

-        if output_cache is None:
-            x_q = x
-        else:
-            assert output_cache.shape[0] == x.shape[0]
-            assert output_cache.shape[1] < x.shape[1]
-            assert output_cache.shape[2] == self.size
-            chunk = x.shape[1] - output_cache.shape[1]
-            x_q = x[:, -chunk:, :]
-            residual = residual[:, -chunk:, :]
-            mask = mask[:, -chunk:, :]
+        x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)

         if self.concat_after:
-            x_concat = paddle.concat(
-                (x, self.self_attn(x_q, x, x, mask)), axis=-1)
+            x_concat = paddle.concat((x, x_att), axis=-1)
             x = residual + self.concat_linear(x_concat)
         else:
-            x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
+            x = residual + self.dropout(x_att)
         if not self.normalize_before:
             x = self.norm1(x)
@@ -128,11 +124,8 @@ class TransformerEncoderLayer(nn.Layer):
         if not self.normalize_before:
             x = self.norm2(x)

-        if output_cache is not None:
-            x = paddle.concat([output_cache, x], axis=1)
-
-        fake_cnn_cache = paddle.zeros([1], dtype=x.dtype)
-        return x, mask, fake_cnn_cache
+        fake_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
+        return x, mask, new_att_cache, fake_cnn_cache


 class ConformerEncoderLayer(nn.Layer):
@@ -192,32 +185,41 @@ class ConformerEncoderLayer(nn.Layer):
         self.size = size
         self.normalize_before = normalize_before
         self.concat_after = concat_after
-        self.concat_linear = Linear(size + size, size)
+        if self.concat_after:
+            self.concat_linear = Linear(size + size, size)
+        else:
+            self.concat_linear = nn.Identity()

     def forward(
             self,
             x: paddle.Tensor,
             mask: paddle.Tensor,
             pos_emb: paddle.Tensor,
-            mask_pad: Optional[paddle.Tensor]=None,
-            output_cache: Optional[paddle.Tensor]=None,
-            cnn_cache: Optional[paddle.Tensor]=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+            mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.
         Args:
-            x (paddle.Tensor): (#batch, time, size)
-            mask (paddle.Tensor): Mask tensor for the input (#batch, time, time).
-            pos_emb (paddle.Tensor): positional encoding, must not be None
-                for ConformerEncoderLayer.
-            mask_pad (paddle.Tensor): batch padding mask used for conv module, (B, 1, T).
-            output_cache (paddle.Tensor): Cache tensor of the encoder output
-                (#batch, time2, size), time2 < time in x.
+            x (paddle.Tensor): Input tensor (#batch, time, size).
+            mask (paddle.Tensor): Mask tensor for the input (#batch, time, time),
+                (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): positional encoding, must not be None
+                for ConformerEncoderLayer
+            mask_pad (paddle.Tensor): batch padding mask used for conv module,
+                (#batch, 1, time), (0, 0, 0) means fake mask.
+            att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
+                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+            cnn_cache (paddle.Tensor): Convolution cache in conformer layer
+                (#batch=1, size, cache_t2)
         Returns:
             paddle.Tensor: Output tensor (#batch, time, size).
-            paddle.Tensor: Mask tensor (#batch, time).
-            paddle.Tensor: New cnn cache tensor (#batch, channels, time').
+            paddle.Tensor: Mask tensor (#batch, time, time).
+            paddle.Tensor: att_cache tensor,
+                (#batch=1, head, cache_t1 + time, d_k * 2).
+            paddle.Tensor: cnn_cache tensor (#batch, size, cache_t2).
         """
         # whether to use macaron style FFN
         if self.feed_forward_macaron is not None:
             residual = x
@@ -233,18 +235,8 @@ class ConformerEncoderLayer(nn.Layer):
         if self.normalize_before:
             x = self.norm_mha(x)

-        if output_cache is None:
-            x_q = x
-        else:
-            assert output_cache.shape[0] == x.shape[0]
-            assert output_cache.shape[1] < x.shape[1]
-            assert output_cache.shape[2] == self.size
-            chunk = x.shape[1] - output_cache.shape[1]
-            x_q = x[:, -chunk:, :]
-            residual = residual[:, -chunk:, :]
-            mask = mask[:, -chunk:, :]
-
-        x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+        x_att, new_att_cache = self.self_attn(
+            x, x, x, mask, pos_emb, cache=att_cache)

         if self.concat_after:
             x_concat = paddle.concat((x, x_att), axis=-1)
@@ -257,7 +249,7 @@ class ConformerEncoderLayer(nn.Layer):
         # convolution module
         # Fake new cnn cache here, and then change it in conv_module
-        new_cnn_cache = paddle.zeros([1], dtype=x.dtype)
+        new_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
         if self.conv_module is not None:
             residual = x
             if self.normalize_before:
@@ -282,7 +274,4 @@ class ConformerEncoderLayer(nn.Layer):
         if self.conv_module is not None:
             x = self.norm_final(x)

-        if output_cache is not None:
-            x = paddle.concat([output_cache, x], axis=1)
-
-        return x, mask, new_cnn_cache
+        return x, mask, new_att_cache, new_cnn_cache
\ No newline at end of file
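Per the docstrings above, each layer consumes its own (1, head, cache_t1, d_k * 2) slice of the stacked attention cache and a (B=1, size, cache_t2) cnn cache, and returns grown versions of both. A shape sketch with assumed head=4 and d_k=64 (illustrative; `layer` would be a constructed ConformerEncoderLayer):

import paddle

# First chunk: zero-shaped placeholders are valid inputs for both caches.
att_cache_i = paddle.zeros([1, 4, 0, 128])  # (1, head, cache_t1=0, d_k*2)
cnn_cache_i = paddle.zeros([0, 0, 0])       # fake cnn cache
# x, mask, new_att_cache, new_cnn_cache = layer(
#     x, mask, pos_emb, att_cache=att_cache_i, cnn_cache=cnn_cache_i)
# new_att_cache: (1, 4, 0 + time, 128); new_cnn_cache: (1, size, cache_t2)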
setup.py

@@ -71,7 +71,8 @@ base = [
     "colorlog",
     "pathos == 0.2.8",
     "braceexpand",
-    "pyyaml"
+    "pyyaml",
+    "pybind11",
 ]

 server = [
@@ -90,7 +91,6 @@ requirements = {
     "gpustat",
     "paddlespeech_ctcdecoders",
     "phkit",
-    "pybind11",
     "pypi-kenlm",
     "snakeviz",
     "sox",