PaddlePaddle / DeepSpeech: compare versions
94987f26dfafda70ebc3515f0cc2e6d81ba8478a...a1745944657feaaae2fe22201aedd6d42b0a536e
Commits (12)
- add roformer
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/03e9ea9e52d61ffd4420bf9bfdc2f422752ad29c
  2023-07-12T08:59:00+00:00, Hui Zhang <zhtclz@foxmail.com>
- fix bugs
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/55870ffbb3581af4a0b7aed61a168f80f0f045fb
  2023-07-12T09:36:13+00:00, Hui Zhang <zhtclz@foxmail.com>
- add roformer result
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/3b6b680771c454151f5ac99013bbc934e967f703
  2023-07-12T11:24:48+00:00, Hui Zhang <zhtclz@foxmail.com>
- support position interpolation for langer attention context windown length.
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/b91b1c9b083002fb716c60c29adf3a20e51262e1
  2023-07-13T03:58:31+00:00, Hui Zhang <zhtclz@foxmail.com>
- RoPE with position interpolation
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/b56fb85ca08e7a24b9fc9f8859d9c5a472b553fa
  2023-07-14T04:45:24+00:00, Hui Zhang <zhtclz@foxmail.com>
- rope for streaming decoding
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/0a5cc5556e602a202304860d5221f9b573582196
  2023-07-14T07:37:55+00:00, Hui Zhang <zhtclz@foxmail.com>
- update result
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/596f71407cd0daa8f1d3e1edd60ce44300d37413
  2023-07-17T02:49:07+00:00, Hui Zhang <zhtclz@foxmail.com>
- fix rotary embeding
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/d94db47f784d30b4c8b07c5f2a44c82cc4c7f24f
  2023-07-17T03:13:24+00:00, Hui Zhang <zhtclz@foxmail.com>
- Merge pull request #3407 from zh794390558/roformer (Roformer)
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/897dcc37e65fa3260d72045afd58f79741240191
  2023-07-20T10:44:00+08:00, Hui Zhang <zhtclz@foxmail.com>
- Update README.md
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/5d10d6e884d0d7eccf4f1724c3ee64bf70a25aaa
  2023-07-21T16:51:44+08:00, Hui Zhang <zhtclz@foxmail.com>
- fix weight decay
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/2faa49a39fcd810a2c896f61006c6e9958a5e85c
  2023-07-25T02:38:36+00:00, Hui Zhang <zhtclz@foxmail.com>
- Merge pull request #3424 from zh794390558/fix_opt (fix weight decay)
  https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/a1745944657feaaae2fe22201aedd6d42b0a536e
  2023-07-26T09:55:52+08:00, Hui Zhang <zhtclz@foxmail.com>
Showing 11 changed files with 552 additions and 64 deletions (+552 -64)
- README.md: +0 -4
- examples/aishell/asr1/RESULTS.md: +37 -9
- examples/aishell/asr1/conf/chunk_roformer.yaml: +98 -0
- examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml: +98 -0
- paddlespeech/dataset/s2t/avg_model.py: +38 -31
- paddlespeech/s2t/models/u2/u2.py: +2 -1
- paddlespeech/s2t/modules/attention.py: +142 -1
- paddlespeech/s2t/modules/embedding.py: +98 -3
- paddlespeech/s2t/modules/encoder.py: +35 -10
- paddlespeech/s2t/modules/encoder_layer.py: +3 -3
- paddlespeech/s2t/training/optimizer/__init__.py: +1 -2
README.md (view file @ a1745944)
...
@@ -893,10 +893,6 @@ The Text-to-Speech module is originally called [Parakeet](https://github.com/Pad
- **[VTuberTalk](https://github.com/jerryuhoo/VTuberTalk): Use PaddleSpeech TTS and ASR to clone voice from videos.**
<div align="center">
<img src="https://raw.githubusercontent.com/jerryuhoo/VTuberTalk/main/gui/gui.png" width = "500px" />
</div>
## Citation
...
examples/aishell/asr1/RESULTS.md (view file @ a1745944)
# Aishell
## Conformer
paddle version: 2.2.2
paddlespeech version: 1.0.1
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug| test | ctc_prefix_beam_search | - | 0.0480 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
## RoFormer Streaming
paddle version: 2.5.0
paddlespeech version: 1.5.0
Tesla V100-SXM2-32GB: 1 node, 4 card
Global BatchSize: 32 * 4
Training Done: 1 day, 12:56:39.639646
### `decoding.decoding_chunk_size=16`
> chunk_size=16, ((16 - 1) * 4 + 7) * 10ms = (16 * 4 + 3) * 10ms = 670ms
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention | 16, -1 | - | 5.63 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 6.13 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 6.13 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 5.44 |
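
The latency note for `decoding.decoding_chunk_size=16` can be checked with a few lines of Python. This is a minimal sketch: the function name and the parameter names `subsampling`, `context` and `frame_shift_ms` are illustrative assumptions, and their default values (4, 7 and 10 ms) are simply the constants that appear in the formula above.

```python
def chunk_latency_ms(chunk_size: int,
                     subsampling: int = 4,
                     context: int = 7,
                     frame_shift_ms: int = 10) -> int:
    """Audio span (in ms) needed to produce one decoding chunk.

    Mirrors ((chunk_size - 1) * 4 + 7) * 10ms from the note above.
    """
    input_frames = (chunk_size - 1) * subsampling + context
    return input_frames * frame_shift_ms


print(chunk_latency_ms(16))
```

Changing `chunk_size` shows how the span grows linearly with the chunk size.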
### `decoding.decoding_chunk_size=-1`
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention | -1, -1 | - | 5.39 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_greedy_search | -1, -1 | - | 5.51 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_prefix_beam_search | -1, -1 | - | 5.51 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention_rescoring | -1, -1 | - | 4.99 |
## Conformer Streaming
...
...
@@ -24,6 +41,17 @@ Need set `decoding.decoding_chunk_size=16` when decoding.
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.051968 |
## Conformer
paddle version: 2.2.2
paddlespeech version: 1.0.1
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_prefix_beam_search | - | 0.0480 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
## Transformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
...
...
examples/aishell/asr1/conf/chunk_roformer.yaml (new file, mode 0 → 100644; view file @ a1745944)
############################################
# Network Architecture                     #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: True
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
    selfattention_layer_type: 'rel_selfattn' # unused
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false
# decoder related
decoder: transformer # transformer, bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    r_num_blocks: 0 # only for bitransformer
    dropout_rate: 0.1 # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1 # label smoothing option
    reverse_weight: 0.0 # only for bitransformer
    length_normalized_loss: false
    init_type: 'kaiming_uniform' # !Warning: need to convergence

###########################################
# Data                                    #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test

###########################################
# Dataloader                              #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1

###########################################
# Training                                #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
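
To make the `scheduler: warmuplr` settings above more concrete, here is a hedged sketch of a Noam-style warmup schedule. The diff does not show the scheduler's formula, so this assumes `warmuplr` follows the schedule commonly used in ESPnet and WeNet; the function name `warmup_lr` is hypothetical and `lr_decay` is not modelled.

```python
def warmup_lr(step: int, base_lr: float = 0.001,
              warmup_steps: int = 25000) -> float:
    """Assumed Noam-style schedule: linear ramp-up, then 1/sqrt(step) decay.

    lr = base_lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    """
    assert step >= 1, "steps are counted from 1 in this sketch"
    return base_lr * warmup_steps**0.5 * min(step**-0.5,
                                             step * warmup_steps**-1.5)


for step in (1, 12500, 25000, 100000):
    print(step, warmup_lr(step))
```

Under this assumption the learning rate reaches the configured `lr` exactly at `warmup_steps` and decays afterwards.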
examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml (new file, mode 0 → 100644; view file @ a1745944)
############################################
# Network Architecture                     #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: True
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
    selfattention_layer_type: 'rel_selfattn' # unused
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false
# decoder related
decoder: bitransformer # transformer, bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 3
    r_num_blocks: 3 # only for bitransformer
    dropout_rate: 0.1 # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1 # label smoothing option
    reverse_weight: 0.3 # only for bitransformer
    length_normalized_loss: false
    init_type: 'kaiming_uniform' # !Warning: need to convergence

###########################################
# Data                                    #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test

###########################################
# Dataloader                              #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1

###########################################
# Training                                #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
paddlespeech/dataset/s2t/avg_model.py (view file @ a1745944)
...
@@ -20,30 +20,6 @@ import numpy as np
 import paddle
 
 
-def define_argparse():
-    parser = argparse.ArgumentParser(description='average model')
-    parser.add_argument('--dst_model', required=True, help='averaged model')
-    parser.add_argument(
-        '--ckpt_dir', required=True, help='ckpt model dir for average')
-    parser.add_argument(
-        '--val_best', action="store_true", help='averaged model')
-    parser.add_argument(
-        '--num', default=5, type=int, help='nums for averaged model')
-    parser.add_argument(
-        '--min_epoch',
-        default=0,
-        type=int,
-        help='min epoch used for averaging model')
-    parser.add_argument(
-        '--max_epoch',
-        default=65536,  # Big enough
-        type=int,
-        help='max epoch used for averaging model')
-
-    args = parser.parse_args()
-    return args
-
-
 def average_checkpoints(dst_model="",
                         ckpt_dir="",
                         val_best=True,
...
@@ -85,7 +61,7 @@ def average_checkpoints(dst_model="",
     print(path_list)
 
     avg = None
-    num = args.num
+    num = num
     assert num == len(path_list)
     for path in path_list:
         print(f'Processing {path}')
...
@@ -100,14 +76,14 @@ def average_checkpoints(dst_model="",
         if avg[k] is not None:
             avg[k] /= num
 
-    paddle.save(avg, args.dst_model)
-    print(f'Saving to {args.dst_model}')
+    paddle.save(avg, dst_model)
+    print(f'Saving to {dst_model}')
 
-    meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
+    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
     with open(meta_path, 'w') as f:
         data = json.dumps({
-            "mode": 'val_best' if args.val_best else 'latest',
-            "avg_ckpt": args.dst_model,
+            "mode": 'val_best' if val_best else 'latest',
+            "avg_ckpt": dst_model,
             "val_loss_mean": avg_val_score,
             "ckpts": path_list,
             "epochs": selected_epochs.tolist(),
...
@@ -116,9 +92,40 @@ def average_checkpoints(dst_model="",
         f.write(data + "\n")
 
 
+def define_argparse():
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='averaged model')
+    parser.add_argument(
+        '--ckpt_dir', required=True, help='ckpt model dir for average')
+    parser.add_argument(
+        '--val_best', action="store_true", help='averaged model')
+    parser.add_argument(
+        '--num', default=5, type=int, help='nums for averaged model')
+    parser.add_argument(
+        '--min_epoch',
+        default=0,
+        type=int,
+        help='min epoch used for averaging model')
+    parser.add_argument(
+        '--max_epoch',
+        default=65536,  # Big enough
+        type=int,
+        help='max epoch used for averaging model')
+
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
 def main():
     args = define_argparse()
-    average_checkpoints(args)
+    average_checkpoints(
+        dst_model=args.dst_model,
+        ckpt_dir=args.ckpt_dir,
+        val_best=args.val_best,
+        num=args.num,
+        min_epoch=args.min_epoch,
+        max_epoch=args.max_epoch)
 
 
 if __name__ == '__main__':
...
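
The core of `average_checkpoints` is an element-wise mean over the parameters of the selected checkpoints (the `avg[k] /= num` step in the hunk above). The following is a minimal, framework-free sketch of that idea; it uses plain Python lists instead of Paddle tensors, and the helper name `average_state_dicts` is an illustrative assumption rather than anything in the repository.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]


def average_state_dicts(states: List[StateDict]) -> StateDict:
    """Element-wise mean of parameters that share the same keys and sizes."""
    num = len(states)
    assert num > 0, "need at least one checkpoint"
    # Start from a copy of the first checkpoint, then accumulate the rest.
    avg = {k: list(v) for k, v in states[0].items()}
    for state in states[1:]:
        for k, v in state.items():
            avg[k] = [a + b for a, b in zip(avg[k], v)]
    # Divide the accumulated sums by the number of checkpoints.
    return {k: [a / num for a in v] for k, v in avg.items()}


ckpts = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}, {"w": [5.0, 6.0]}]
print(average_state_dicts(ckpts))
```

Passing the values explicitly, as the refactored `main()` does, keeps a function like this callable from other code without going through `argparse`.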
paddlespeech/s2t/models/u2/u2.py (view file @ a1745944)
...
@@ -145,7 +145,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
                                            text_lengths)
             ctc_time = time.time() - start
-            #logger.debug(f"ctc time: {ctc_time}")
 
         if loss_ctc is None:
             loss = loss_att
         elif loss_att is None:
...
@@ -916,6 +915,8 @@ class U2Model(U2DecodeModel):
         decoder_type = configs.get('decoder', 'transformer')
         logger.debug(f"U2 Decoder type: {decoder_type}")
         if decoder_type == 'transformer':
+            configs['model_conf'].pop('reverse_weight', None)
+            configs['decoder_conf'].pop('r_num_blocks', None)
             decoder = TransformerDecoder(vocab_size,
                                          encoder.output_size(),
                                          **configs['decoder_conf'])
...
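
The two `pop` calls added above presumably exist because `decoder_conf` is expanded with `**` into the decoder constructor, so a key meant only for `bitransformer` would otherwise reach a constructor that does not accept it. The sketch below illustrates that failure mode; `build_decoder` is a hypothetical stand-in for a constructor with a fixed signature, not the real `TransformerDecoder`.

```python
def build_decoder(vocab_size: int, encoder_dim: int, *,
                  attention_heads: int = 4, num_blocks: int = 6) -> dict:
    """Hypothetical decoder factory that rejects unknown keyword arguments."""
    return {"vocab_size": vocab_size, "encoder_dim": encoder_dim,
            "attention_heads": attention_heads, "num_blocks": num_blocks}


configs = {
    "decoder": "transformer",
    "decoder_conf": {"attention_heads": 4, "num_blocks": 6, "r_num_blocks": 0},
    "model_conf": {"ctc_weight": 0.3, "reverse_weight": 0.0},
}

# Without cleanup, the bitransformer-only key is passed through and rejected.
try:
    build_decoder(100, 256, **configs["decoder_conf"])
    failed = False
except TypeError:
    failed = True

# Same cleanup as in the diff: drop keys that only bitransformer understands.
if configs["decoder"] == "transformer":
    configs["model_conf"].pop("reverse_weight", None)
    configs["decoder_conf"].pop("r_num_blocks", None)

decoder = build_decoder(100, 256, **configs["decoder_conf"])
print(failed, decoder)
```

Using `pop(key, None)` keeps the cleanup safe when a config omits the key entirely.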
paddlespeech/s2t/modules/attention.py (view file @ a1745944)
...
@@ -15,6 +15,7 @@
 # Modified from wenet(https://github.com/wenet-e2e/wenet)
 """Multi-Head Attention layer definition."""
 import math
+from typing import List
 from typing import Tuple
 
 import paddle
...
@@ -26,7 +27,10 @@ from paddlespeech.s2t.utils.log import Log
 logger = Log(__name__).getlog()
 
-__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
+__all__ = [
+    "MultiHeadedAttention", "RelPositionMultiHeadedAttention",
+    "RoPERelPositionMultiHeadedAttention"
+]
 
 # Relative Positional Encodings
 # https://www.jianshu.com/p/c0608efcc26f
...
@@ -165,6 +169,7 @@ class MultiHeadedAttention(nn.Layer):
                 and `head * d_k == size`
 
         """
+        # (B,T,D) -> (B,T,H,D/H)
         q, k, v = self.forward_qkv(query, key, value)
 
         # when export onnx model, for 1st chunk, we feed
...
@@ -373,3 +378,139 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             self.d_k)  # (batch, head, time1, time2)
 
         return self.forward_attention(v, scores, mask), new_cache
+
+
+class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with RoPE relative position encoding."""
+
+    def __init__(self,
+                 n_head,
+                 n_feat,
+                 dropout_rate,
+                 adaptive_scale=False,
+                 init_weights=False):
+        """Construct an RelPositionMultiHeadedAttention object.
+        Paper: https://arxiv.org/abs/1901.02860
+        Args:
+            n_head (int): The number of heads.
+            n_feat (int): The number of features.
+            dropout_rate (float): Dropout rate.
+        """
+        super().__init__(n_head, n_feat, dropout_rate)
+
+    def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
+        """Realign a tensor (a batched version of expand_dims).
+        axes: the i-th dimension of the original tensor is aligned to
+            dimension axes[i] of the new tensor;
+        ndim: number of dimensions of the new tensor.
+        """
+        assert len(axes) == tensor.dim()
+        assert ndim or min(axes) >= 0
+
+        ndim = ndim or max(axes) + 1
+
+        # a[0, None, 1] = a[0, np.newaxis, 1]
+        indices = [None] * ndim
+        for i in axes:
+            # slice nothing, a[0, slice(None), 1] = a[0, :, 1]
+            indices[i] = slice(None)
+
+        return tensor[indices]
+
+    def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
+        """Apply RoPE to tensors.
+        Here sinusoidal.shape=[B, T, D], tensors is a list of tensors, and
+        tensor.shape=[B, T, ..., D], or (B,H,T,D/H)
+        """
+        assert len(tensors) > 0, 'at least one input tensor'
+        assert all(
+            [tensor.shape == tensors[0].shape
+             for tensor in tensors[1:]]), 'all tensors must have the same shape'
+
+        # (B,H,T,D)
+        ndim = tensors[0].dim()
+        _, H, T, D = tensors[0].shape
+
+        # sinusoidal shape same with tensors[0]
+        # [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
+        # sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+        sinusoidal = sinusoidal.reshape((1, T, H, D)).transpose([0, 2, 1, 3])
+
+        # http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
+        # like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
+        # [b,T, ..., d/2] -> [b,T, ..., d]
+        cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
+        sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
+        outputs = []
+        for tensor in tensors:
+            # x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
+            tensor2 = paddle.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
+            tensor2 = paddle.reshape(tensor2, paddle.shape(tensor))
+
+            # Eq. 34, out = x * cos_pos + x2 * sin_pos
+            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
+        return outputs[0] if len(outputs) == 1 else outputs
+
+    def forward(self,
+                query: paddle.Tensor,
+                key: paddle.Tensor,
+                value: paddle.Tensor,
+                mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+                pos_emb: paddle.Tensor=paddle.empty([0]),
+                cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+        Ref: https://github.com/facebookresearch/llama/blob/main/llama/model.py
+        Args:
+            query (paddle.Tensor): Query tensor (#batch, time1, size).
+            key (paddle.Tensor): Key tensor (#batch, time2, size).
+            value (paddle.Tensor): Value tensor (#batch, time2, size).
+            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        Returns:
+            paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
+
+        # f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
+        # q_t always is chunk_size
+        q_t = q.shape[2]
+        q = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], q)
+        # k will increase when in streaming decoding.
+        k = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], k)
+
+        # when export onnx model, for 1st chunk, we feed
+        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        # In all modes, `if cache.size(0) > 0` will always be `True`
+        #   and we will always do splitting and
+        #   concatenation (this will simplify onnx export). Note that
+        #   it's OK to concat & split zero-shaped tensors (see code below).
+        # when export jit model, for 1st chunk, we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if cache.shape[0] > 0:
+            # last dim `d_k * 2` for (key, val)
+            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        #   non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)
+
+        # dot(q, k)
+        scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask), new_cache
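
The rotation performed by `apply_rotary_position_embeddings` can be illustrated without Paddle. The sketch below is a pure-Python illustration under stated assumptions: the helper names `sinusoid`, `apply_rope` and `dot` are hypothetical, vectors are plain lists, and the interleaved `[sin, cos, ...]` layout follows the `pe[:, :, 0::2] = sin`, `pe[:, :, 1::2] = cos` convention seen in `embedding.py`. It checks the property that motivates RoPE: the query-key score depends only on the relative offset between positions.

```python
import math
from typing import List


def sinusoid(pos: float, dim: int, base: float = 10000.0) -> List[float]:
    """Interleaved [sin, cos, sin, cos, ...] values for one position."""
    out: List[float] = []
    for i in range(dim // 2):
        angle = pos * base ** (-2 * i / dim)
        out += [math.sin(angle), math.cos(angle)]
    return out


def apply_rope(x: List[float], pos: float) -> List[float]:
    """out = x * cos_pos + x2 * sin_pos, with x2 = [-x1, x0, -x3, x2, ...]."""
    s = sinusoid(pos, len(x))
    out: List[float] = []
    for i in range(0, len(x), 2):
        sin_p, cos_p = s[i], s[i + 1]
        out.append(x[i] * cos_p - x[i + 1] * sin_p)
        out.append(x[i + 1] * cos_p + x[i] * sin_p)
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(u * v for u, v in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]
# Both pairs of positions are 3 steps apart, so the scores should match.
score_a = dot(apply_rope(q, 5), apply_rope(k, 2))
score_b = dot(apply_rope(q, 13), apply_rope(k, 10))
print(abs(score_a - score_b) < 1e-9)
```

Because each pair of features is rotated rather than scaled, the vector norm is unchanged and only the relative angle between query and key enters the dot product.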
paddlespeech/s2t/modules/embedding.py (view file @ a1745944)
...
@@ -85,18 +85,21 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
             reverse (bool, optional): Not used. Defaults to False.
         """
         nn.Layer.__init__(self)
-        self.d_model = d_model
+        self.d_model = paddle.to_tensor(d_model)
         self.max_len = max_len
         self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
         self.dropout = nn.Dropout(p=dropout_rate)
+        self.base = paddle.to_tensor(10000.0)
         self.pe = paddle.zeros([1, self.max_len, self.d_model])  #[B=1,T,D]
 
         position = paddle.arange(
             0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
+        # base^{-2(i-1)/d)}, i \in (1,2...,d/2)
         div_term = paddle.exp(
-            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-            -(math.log(10000.0) / self.d_model))
+            -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
+            (paddle.log(self.base) / self.d_model))
 
+        # [B,T,D]
         self.pe[:, :, 0::2] = paddle.sin(position * div_term)
         self.pe[:, :, 1::2] = paddle.cos(position * div_term)
...
@@ -161,6 +164,98 @@ class RelPositionalEncoding(PositionalEncoding):
         assert offset + x.shape[1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
             offset, x.shape[1], self.max_len)
         x = x * self.xscale
         pos_emb = self.pe[:, offset:offset + x.shape[1]]
         return self.dropout(x), self.dropout(pos_emb)
+
+
+# RotaryRelPositionalEncoding is same to RelPositionalEncoding
+class ScaledRotaryRelPositionalEncoding(RelPositionalEncoding):
+    """Scaled Rotary Relative positional encoding module.
+    POSITION INTERPOLATION: https://arxiv.org/pdf/2306.15595v2.pdf
+    """
+
+    def __init__(self,
+                 d_model: int,
+                 dropout_rate: float,
+                 max_len: int=5000,
+                 scale=1):
+        """
+        Args:
+            d_model (int): Embedding dimension.
+            dropout_rate (float): Dropout rate.
+            max_len (int, optional): [Maximum input length.]. Defaults to 5000.
+            scale (int): Interpolation max input length to `scale * max_len` positions.
+        """
+        super().__init__(d_model, dropout_rate, max_len, reverse=True)
+        self.pscale = paddle.to_tensor(scale)
+        self.max_len = max_len * scale
+
+    def sinusoidal_embeddings(self,
+                              pos: paddle.Tensor,
+                              dim: paddle.Tensor,
+                              base=10000) -> paddle.Tensor:
+        """Compute the dim-dimensional sinusoidal encoding for positions pos."""
+        assert dim % 2 == 0
+        # (d/2,)
+        indices = paddle.arange(0, dim // 2, dtype=pos.dtype)
+        indices = paddle.pow(paddle.cast(base, pos.dtype), -2 * indices / dim)
+        # pos (1, T), indices (d/2,) -> (1, T, d/2)
+        embeddings = paddle.einsum('...,d->...d', pos, indices)
+        # (1, T, d/2, 2)
+        embeddings = paddle.stack(
+            [paddle.sin(embeddings), paddle.cos(embeddings)], axis=-1)
+        # (1, T, d)
+        embeddings = paddle.flatten(embeddings, start_axis=-2, stop_axis=-1)
+        return embeddings
+
+    def forward(self, x: paddle.Tensor,
+                offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute positional encoding.
+        Args:
+            x (paddle.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            paddle.Tensor: Encoded tensor (batch, time, `*`).
+            paddle.Tensor: Positional embedding tensor (1, time, `*`).
+        """
+        x = x * self.xscale
+
+        B, T, D = x.shape
+        assert D == self.d_model
+
+        # position interpolation
+        start = 0
+        end = T * self.pscale
+        assert end <= self.max_len
+        position = paddle.arange(start, end, dtype=x.dtype).unsqueeze(0)
+        position *= 1.0 / self.pscale
+        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
+
+        pos_emb = pe[:, offset:offset + x.shape[1]]
+        return self.dropout(x), self.dropout(pos_emb)
+
+    def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
+        """ For getting encoding in a streaming fashion
+        Attention!!!!!
+        we apply dropout only once at the whole utterance level in a none
+        streaming way, but will call this function several times with
+        increasing input size in a streaming scenario, so the dropout will
+        be applied several times.
+        Args:
+            offset (int): start offset
+            size (int): required size of position encoding
+        Returns:
+            paddle.Tensor: Corresponding position encoding, #[1, T, D].
+        """
+        # position interpolation
+        start = offset
+        end = (offset + size) * self.pscale
+        assert end <= self.max_len
+        position = paddle.arange(
+            start, end, dtype=paddle.get_default_dtype()).unsqueeze(0)
+        position *= 1.0 / self.pscale
+
+        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
+
+        return self.dropout(pe)
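
The position-interpolation step in `ScaledRotaryRelPositionalEncoding.forward` boils down to dividing integer positions by `scale` before computing the sinusoids. Below is a minimal pure-Python sketch of that arithmetic; the function names are illustrative assumptions, lists replace tensors, and only the non-streaming `forward` path with `offset = 0` is mirrored.

```python
import math
from typing import List


def interpolated_positions(length: int, scale: int = 1) -> List[float]:
    """Positions 0, 1/scale, 2/scale, ... truncated to `length` entries.

    Mirrors arange(0, T * scale) * (1 / scale) followed by pe[:, 0:T].
    """
    return [p / scale for p in range(length * scale)][:length]


def sinusoidal_embedding(pos: float, dim: int,
                         base: float = 10000.0) -> List[float]:
    """Interleaved [sin, cos, ...] encoding of a (possibly fractional) position."""
    assert dim % 2 == 0
    emb: List[float] = []
    for i in range(dim // 2):
        angle = pos * base ** (-2 * i / dim)
        emb += [math.sin(angle), math.cos(angle)]
    return emb


positions = interpolated_positions(4, scale=2)
print(positions)
print([sinusoidal_embedding(p, 4) for p in positions])
```

With `scale=2`, four frames occupy the position range that two frames would occupy without scaling, which is how longer inputs are squeezed into the range of positions seen during training.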
paddlespeech/s2t/modules/encoder.py (view file @ a1745944)
...
@@ -28,6 +28,7 @@ from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.attention import MultiHeadedAttention
 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
+from paddlespeech.s2t.modules.attention import RoPERelPositionMultiHeadedAttention
 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
 from paddlespeech.s2t.modules.embedding import PositionalEncoding
...
@@ -115,6 +116,8 @@ class BaseEncoder(nn.Layer):
             pos_enc_class = PositionalEncoding
         elif pos_enc_layer_type == "rel_pos":
             pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "rope_pos":
+            pos_enc_class = RelPositionalEncoding
         elif pos_enc_layer_type == "no_pos":
             pos_enc_class = NoPositionalEncoding
         else:
...
@@ -230,14 +233,14 @@ class BaseEncoder(nn.Layer):
             xs = self.global_cmvn(xs)
 
         # before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
-        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
+        xs, _, _ = self.embed(xs, tmp_masks, offset=offset)
         # after embed, xs=(B=1, chunk_size, hidden-dim)
 
         elayers, _, cache_t1, _ = att_cache.shape
         chunk_size = xs.shape[1]
         attention_key_size = cache_t1 + chunk_size
 
-        # only used when using `RelPositionMultiHeadedAttention`
+        # only used when using `RelPositionMultiHeadedAttention` and `RoPERelPositionMultiHeadedAttention`
         pos_emb = self.embed.position_encoding(
             offset=offset - cache_t1, size=attention_key_size)
...
@@ -474,21 +477,35 @@ class ConformerEncoder(BaseEncoder):
         activation = get_activation(activation_type)
 
         # self-attention module definition
-        encoder_selfattn_layer = RelPositionMultiHeadedAttention
-        encoder_selfattn_layer_args = (attention_heads, output_size,
-                                       attention_dropout_rate)
+        encoder_dim = output_size
+        if pos_enc_layer_type == "abs_pos":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rel_pos":
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
+
         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
-        positionwise_layer_args = (output_size, linear_units, dropout_rate,
+        positionwise_layer_args = (encoder_dim, linear_units, dropout_rate,
                                    activation)
         # convolution module definition
         convolution_layer = ConvolutionModule
-        convolution_layer_args = (output_size, cnn_module_kernel, activation,
+        convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
                                   cnn_module_norm, causal)
         self.encoders = nn.LayerList([
             ConformerEncoderLayer(
-                size=output_size,
+                size=encoder_dim,
                 self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
                 feed_forward=positionwise_layer(*positionwise_layer_args),
                 feed_forward_macaron=positionwise_layer(
...
@@ -580,15 +597,23 @@ class SqueezeformerEncoder(nn.Layer):
         activation = get_activation(activation_type)
 
         # self-attention module definition
-        if pos_enc_layer_type != "rel_pos":
+        if pos_enc_layer_type == "abs_pos":
             encoder_selfattn_layer = MultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, output_size,
                                            attention_dropout_rate)
-        else:
+        elif pos_enc_layer_type == "rel_pos":
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, encoder_dim,
                                            attention_dropout_rate,
                                            adaptive_scale, init_weights)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate,
+                                           adaptive_scale, init_weights)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
 
         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
...
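
The encoder hunks above replace an implicit default with an explicit three-way choice that raises on unknown values. As a hedged illustration of the same selection logic, the sketch below uses a lookup table; the names `ATTENTION_BY_POS_ENC` and `select_attention` are hypothetical, and class names are kept as strings so the snippet runs without PaddleSpeech installed.

```python
# Hypothetical lookup table mirroring the if/elif chain in the encoders.
ATTENTION_BY_POS_ENC = {
    "abs_pos": "MultiHeadedAttention",
    "rel_pos": "RelPositionMultiHeadedAttention",
    "rope_pos": "RoPERelPositionMultiHeadedAttention",
}


def select_attention(pos_enc_layer_type: str) -> str:
    """Return the attention class name for a positional-encoding type."""
    try:
        return ATTENTION_BY_POS_ENC[pos_enc_layer_type]
    except KeyError:
        # Same message as the ValueError raised in the diff.
        raise ValueError(
            f"pos_enc_layer_type {pos_enc_layer_type} not supported."
        ) from None


print(select_attention("rope_pos"))
```

Either form makes a misspelled `pos_enc_layer_type` fail at construction time instead of silently selecting a default attention layer.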
paddlespeech/s2t/modules/encoder_layer.py (view file @ a1745944)
...
@@ -48,7 +48,7 @@ class TransformerEncoderLayer(nn.Layer):
         Args:
             size (int): Input dimension.
             self_attn (nn.Layer): Self-attention module instance.
-                `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+                `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
                 instance can be used as the argument.
             feed_forward (nn.Layer): Feed-forward module instance.
                 `PositionwiseFeedForward`, instance can be used as the argument.
...
@@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer):
         Args:
             size (int): Input dimension.
             self_attn (nn.Layer): Self-attention module instance.
-                `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+                `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
                 instance can be used as the argument.
             feed_forward (nn.Layer): Feed-forward module instance.
                 `PositionwiseFeedForward` instance can be used as the argument.
...
@@ -298,7 +298,7 @@ class SqueezeformerEncoderLayer(nn.Layer):
         Args:
             size (int): Input dimension.
             self_attn (paddle.nn.Layer): Self-attention module instance.
-                `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+                `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
                 instance can be used as the argument.
             feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
                 `PositionwiseFeedForward` instance can be used as the argument.
...
paddlespeech/s2t/training/optimizer/__init__.py (view file @ a1745944)
...
@@ -102,8 +102,7 @@ class OptimizerFactory():
         grad_clip = paddle.nn.ClipGradByGlobalNorm(
             args['grad_clip']) if "grad_clip" in args else None
-        weight_decay = L2Decay(
-            args['weight_decay']) if "weight_decay" in args else None
+        weight_decay = args.get("weight_decay", None)
         if weight_decay:
             logger.info(f'<WeightDecay - {weight_decay}>')
         if grad_clip:
...