Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
55870ffb
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
9 个月 前同步成功
通知
200
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
55870ffb
编写于
7月 12, 2023
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs
上级
03e9ea9e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
19 addition
and
14 deletion
+19
-14
examples/aishell/asr1/conf/chunk_roformer.yaml
examples/aishell/asr1/conf/chunk_roformer.yaml
+3
-3
examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
+1
-1
paddlespeech/s2t/models/u2/u2.py
paddlespeech/s2t/models/u2/u2.py
+2
-1
paddlespeech/s2t/modules/attention.py
paddlespeech/s2t/modules/attention.py
+8
-5
paddlespeech/s2t/modules/encoder.py
paddlespeech/s2t/modules/encoder.py
+5
-4
未找到文件。
examples/aishell/asr1/conf/chunk_roformer.yaml
浏览文件 @
55870ffb
...
...
@@ -18,7 +18,7 @@ encoder_conf:
cnn_module_kernel
:
15
use_cnn_module
:
True
activation_type
:
'
swish'
pos_enc_layer_type
:
'
r
po
e_pos'
# abs_pos, rel_pos, rope_pos
pos_enc_layer_type
:
'
r
op
e_pos'
# abs_pos, rel_pos, rope_pos
selfattention_layer_type
:
'
rel_selfattn'
# unused
causal
:
true
use_dynamic_chunk
:
true
...
...
@@ -30,7 +30,7 @@ decoder_conf:
attention_heads
:
4
linear_units
:
2048
num_blocks
:
6
r_num_blocks
:
3
# only for bitransformer
r_num_blocks
:
0
# only for bitransformer
dropout_rate
:
0.1
# sublayer output dropout
positional_dropout_rate
:
0.1
self_attention_dropout_rate
:
0.0
...
...
@@ -39,7 +39,7 @@ decoder_conf:
model_conf
:
ctc_weight
:
0.3
lsm_weight
:
0.1
# label smoothing option
reverse_weight
:
0.
3
# only for bitransformer
reverse_weight
:
0.
0
# only for bitransformer
length_normalized_loss
:
false
init_type
:
'
kaiming_uniform'
# !Warning: need to convergence
...
...
examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
浏览文件 @
55870ffb
...
...
@@ -18,7 +18,7 @@ encoder_conf:
cnn_module_kernel
:
15
use_cnn_module
:
True
activation_type
:
'
swish'
pos_enc_layer_type
:
'
r
po
e_pos'
# abs_pos, rel_pos, rope_pos
pos_enc_layer_type
:
'
r
op
e_pos'
# abs_pos, rel_pos, rope_pos
selfattention_layer_type
:
'
rel_selfattn'
# unused
causal
:
true
use_dynamic_chunk
:
true
...
...
paddlespeech/s2t/models/u2/u2.py
浏览文件 @
55870ffb
...
...
@@ -145,7 +145,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
text_lengths
)
ctc_time
=
time
.
time
()
-
start
#logger.debug(f"ctc time: {ctc_time}")
if
loss_ctc
is
None
:
loss
=
loss_att
elif
loss_att
is
None
:
...
...
@@ -916,6 +915,8 @@ class U2Model(U2DecodeModel):
decoder_type
=
configs
.
get
(
'decoder'
,
'transformer'
)
logger
.
debug
(
f
"U2 Decoder type:
{
decoder_type
}
"
)
if
decoder_type
==
'transformer'
:
configs
[
'model_conf'
].
pop
(
'reverse_weight'
,
None
)
configs
[
'decoder_conf'
].
pop
(
'r_num_blocks'
,
None
)
decoder
=
TransformerDecoder
(
vocab_size
,
encoder
.
output_size
(),
**
configs
[
'decoder_conf'
])
...
...
paddlespeech/s2t/modules/attention.py
浏览文件 @
55870ffb
...
...
@@ -16,6 +16,7 @@
"""Multi-Head Attention layer definition."""
import
math
from
typing
import
Tuple
from
typing
import
List
import
paddle
from
paddle
import
nn
...
...
@@ -418,25 +419,27 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
def
apply_rotary_position_embeddings
(
self
,
sinusoidal
,
*
tensors
):
"""应用RoPE到tensors中
其中,sinusoidal.shape=[B, T, D],tensors为tensor的列表,而
tensor.shape=[B, T, ..., D], or (B,
T,H
,D/H)
tensor.shape=[B, T, ..., D], or (B,
H,T
,D/H)
"""
assert
len
(
tensors
)
>
0
,
'at least one input tensor'
assert
all
(
[
tensor
.
shape
==
tensors
[
0
].
shape
for
tensor
in
tensors
[
1
:]]),
'all tensors must have the same shape'
# (B,H,T,D)
ndim
=
tensors
[
0
].
dim
()
_
,
H
,
T
,
D
=
tensors
[
0
].
shape
# sinusoidal shape same with tensors[0]
# [B,T,D] -> [B,T,1,D]
sinusoidal
=
self
.
align
(
sinusoidal
,
[
0
,
1
,
-
1
],
ndim
)
# [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
# sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
sinusoidal
=
sinusoidal
.
reshape
((
1
,
T
,
H
,
D
)).
transpose
([
0
,
2
,
1
,
3
])
# http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
# like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
# [b,T, ..., d/2] -> [b,T, ..., d]
cos_pos
=
paddle
.
repeat_interleave
(
sinusoidal
[...,
1
::
2
],
2
,
axis
=-
1
)
sin_pos
=
paddle
.
repeat_interleave
(
sinusoidal
[...,
0
::
2
],
2
,
axis
=-
1
)
outputs
=
[]
for
tensor
in
tensors
:
# x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
...
...
@@ -501,7 +504,7 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
new_cache
=
paddle
.
concat
((
k
,
v
),
axis
=-
1
)
# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
q
,
k
=
self
.
apply_rotary_position_embeddings
(
pos_emb
,
[
q
,
k
]
)
q
,
k
=
self
.
apply_rotary_position_embeddings
(
pos_emb
,
q
,
k
)
# dot(q, k)
scores
=
paddle
.
matmul
(
q
,
k
,
transpose_y
=
True
)
/
math
.
sqrt
(
self
.
d_k
)
return
self
.
forward_attention
(
v
,
scores
,
mask
),
new_cache
paddlespeech/s2t/modules/encoder.py
浏览文件 @
55870ffb
...
...
@@ -477,9 +477,10 @@ class ConformerEncoder(BaseEncoder):
activation
=
get_activation
(
activation_type
)
# self-attention module definition
encoder_dim
=
output_size
if
pos_enc_layer_type
==
"abs_pos"
:
encoder_selfattn_layer
=
MultiHeadedAttention
encoder_selfattn_layer_args
=
(
attention_heads
,
output_size
,
encoder_selfattn_layer_args
=
(
attention_heads
,
encoder_dim
,
attention_dropout_rate
)
elif
pos_enc_layer_type
==
"rel_pos"
:
encoder_selfattn_layer
=
RelPositionMultiHeadedAttention
...
...
@@ -495,16 +496,16 @@ class ConformerEncoder(BaseEncoder):
# feed-forward module definition
positionwise_layer
=
PositionwiseFeedForward
positionwise_layer_args
=
(
output_size
,
linear_units
,
dropout_rate
,
positionwise_layer_args
=
(
encoder_dim
,
linear_units
,
dropout_rate
,
activation
)
# convolution module definition
convolution_layer
=
ConvolutionModule
convolution_layer_args
=
(
output_size
,
cnn_module_kernel
,
activation
,
convolution_layer_args
=
(
encoder_dim
,
cnn_module_kernel
,
activation
,
cnn_module_norm
,
causal
)
self
.
encoders
=
nn
.
LayerList
([
ConformerEncoderLayer
(
size
=
output_size
,
size
=
encoder_dim
,
self_attn
=
encoder_selfattn_layer
(
*
encoder_selfattn_layer_args
),
feed_forward
=
positionwise_layer
(
*
positionwise_layer_args
),
feed_forward_macaron
=
positionwise_layer
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录