PaddlePaddle / DeepSpeech

Commit 897dcc37 (unverified)

Merge pull request #3407 from zh794390558/roformer

Roformer

Authored by Hui Zhang on Jul 20, 2023; committed via GitHub on Jul 20, 2023.
Parents: 94987f26, d94db47f

Showing 9 changed files with 551 additions and 58 deletions (+551 −58):
- examples/aishell/asr1/RESULTS.md (+37 −9)
- examples/aishell/asr1/conf/chunk_roformer.yaml (+98 −0)
- examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml (+98 −0)
- paddlespeech/dataset/s2t/avg_model.py (+38 −31)
- paddlespeech/s2t/models/u2/u2.py (+2 −1)
- paddlespeech/s2t/modules/attention.py (+142 −1)
- paddlespeech/s2t/modules/embedding.py (+98 −3)
- paddlespeech/s2t/modules/encoder.py (+35 −10)
- paddlespeech/s2t/modules/encoder_layer.py (+3 −3)
examples/aishell/asr1/RESULTS.md (+37 −9)

```diff
 # Aishell

-## Conformer
-paddle version: 2.2.2
-paddlespeech version: 1.0.1
-| Model | Params | Config | Augmentation | Test set | Decode method | Loss | CER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
-| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
-| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_prefix_beam_search | - | 0.0480 |
-| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
+## RoFormer Streaming
+paddle version: 2.5.0
+paddlespeech version: 1.5.0
+Tesla V100-SXM2-32GB: 1 node, 4 card
+Global BatchSize: 32 * 4
+Training Done: 1 day, 12:56:39.639646
+
+### `decoding.decoding_chunk_size=16`
+> chunk_size=16, ((16 - 1) * 4 + 7) * 10ms = (16 * 4 + 3) * 10ms = 670ms
+
+| Model | Params | Config | Augmentation | Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention | 16, -1 | - | 5.63 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 6.13 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 6.13 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 5.44 |
+
+### `decoding.decoding_chunk_size=-1`
+| Model | Params | Config | Augmentation | Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention | -1, -1 | - | 5.39 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_greedy_search | -1, -1 | - | 5.51 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_prefix_beam_search | -1, -1 | - | 5.51 |
+| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention_rescoring | -1, -1 | - | 4.99 |

 ## Conformer Streaming
 ...
@@ -24,6 +41,17 @@ Need set `decoding.decoding_chunk_size=16` when decoding.
 | conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.051968 |

+## Conformer
+paddle version: 2.2.2
+paddlespeech version: 1.0.1
+| Model | Params | Config | Augmentation | Test set | Decode method | Loss | CER |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
+| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
+| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_prefix_beam_search | - | 0.0480 |
+| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
+
 ## Transformer
 | Model | Params | Config | Augmentation | Test set | Decode method | Loss | CER |
 ...
```
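The latency note above is plain frame arithmetic: with conv2d subsampling (rate 4, 7 frames of right context) and a 10 ms frame shift, a 16-frame chunk spans 670 ms of audio. A quick sketch of the calculation, with the rate and context values taken from the note itself rather than read out of the model:

```python
def chunk_span_ms(chunk_size: int,
                  subsampling_rate: int = 4,
                  right_context: int = 7,
                  frame_shift_ms: float = 10.0) -> float:
    """Audio covered by one decoding chunk after conv2d subsampling:
    ((chunk_size - 1) * rate + context) * frame_shift."""
    return ((chunk_size - 1) * subsampling_rate + right_context) * frame_shift_ms

print(chunk_span_ms(16))  # 670.0 -> the 670 ms quoted for chunk_size=16
```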
examples/aishell/asr1/conf/chunk_roformer.yaml (new file, mode 100644, +98 −0)

```yaml
############################################
#           Network Architecture           #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: True
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rope_pos'  # abs_pos, rel_pos, rope_pos
    selfattention_layer_type: 'rel_selfattn'  # unused
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm'   # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false

# decoder related
decoder: transformer  # transformer, bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    r_num_blocks: 0     # only for bitransformer
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    reverse_weight: 0.0 # only for bitransformer
    length_normalized_loss: false
    init_type: 'kaiming_uniform'  # !Warning: need to convergence

###########################################
#                  Data                   #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test

###########################################
#               Dataloader                #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0  # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512   # if input length  > maxlen-in, batchsize is automatically reduced
maxlen_out: 150  # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0   # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1

###########################################
#                Training                 #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
```
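Since the file is plain YAML, the switches that matter for this commit can be sanity-checked before launching a run. A minimal sketch using PyYAML; the path is the file added above, and the asserted values are the ones it contains:

```python
import yaml  # PyYAML

with open("examples/aishell/asr1/conf/chunk_roformer.yaml") as f:
    conf = yaml.safe_load(f)

# 'rope_pos' is the new switch that selects RoPERelPositionMultiHeadedAttention
# in ConformerEncoder (see the encoder.py diff below).
assert conf["encoder_conf"]["pos_enc_layer_type"] == "rope_pos"
# Plain transformer decoder: the bitransformer-only knobs stay at 0 here.
assert conf["decoder"] == "transformer"
assert conf["model_conf"]["reverse_weight"] == 0.0
```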
examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml (new file, mode 100644, +98 −0)

```yaml
############################################
#           Network Architecture           #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: True
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rope_pos'  # abs_pos, rel_pos, rope_pos
    selfattention_layer_type: 'rel_selfattn'  # unused
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm'   # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false

# decoder related
decoder: bitransformer  # transformer, bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 3
    r_num_blocks: 3     # only for bitransformer
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    reverse_weight: 0.3 # only for bitransformer
    length_normalized_loss: false
    init_type: 'kaiming_uniform'  # !Warning: need to convergence

###########################################
#                  Data                   #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test

###########################################
#               Dataloader                #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0  # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512   # if input length  > maxlen-in, batchsize is automatically reduced
maxlen_out: 150  # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0   # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1

###########################################
#                Training                 #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
```
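Relative to chunk_roformer.yaml, the only changes are the `bitransformer` decoder (3 forward plus 3 right-to-left blocks instead of 6 forward) and `reverse_weight: 0.3`. In U2-style hybrid CTC/attention training these weights conventionally combine as sketched below; this is the usual wenet-derived convention written out for orientation, not code copied from this commit:

```python
def hybrid_loss(loss_ctc, loss_att, r_loss_att,
                ctc_weight=0.3, reverse_weight=0.3):
    """CTC/attention mix, with the attention term blending the
    left-to-right decoder loss and the right-to-left (reversed) one."""
    loss_att = (1.0 - reverse_weight) * loss_att + reverse_weight * r_loss_att
    return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att
```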
paddlespeech/dataset/s2t/avg_model.py (+38 −31)

```diff
@@ -20,30 +20,6 @@ import numpy as np
 import paddle

-def define_argparse():
-    parser = argparse.ArgumentParser(description='average model')
-    parser.add_argument('--dst_model', required=True, help='averaged model')
-    parser.add_argument('--ckpt_dir', required=True, help='ckpt model dir for average')
-    parser.add_argument('--val_best', action="store_true", help='averaged model')
-    parser.add_argument('--num', default=5, type=int, help='nums for averaged model')
-    parser.add_argument('--min_epoch', default=0, type=int,
-                        help='min epoch used for averaging model')
-    parser.add_argument('--max_epoch', default=65536,  # Big enough
-                        type=int, help='max epoch used for averaging model')
-    args = parser.parse_args()
-    return args
-
 def average_checkpoints(dst_model="",
                         ckpt_dir="",
                         val_best=True,
@@ -85,7 +61,7 @@ def average_checkpoints(dst_model="",
     print(path_list)

     avg = None
-    num = args.num
+    num = num
     assert num == len(path_list)
     for path in path_list:
         print(f'Processing {path}')
@@ -100,14 +76,14 @@ def average_checkpoints(dst_model="",
         if avg[k] is not None:
             avg[k] /= num

-    paddle.save(avg, args.dst_model)
-    print(f'Saving to {args.dst_model}')
+    paddle.save(avg, dst_model)
+    print(f'Saving to {dst_model}')

-    meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
+    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
     with open(meta_path, 'w') as f:
         data = json.dumps({
-            "mode": 'val_best' if args.val_best else 'latest',
-            "avg_ckpt": args.dst_model,
+            "mode": 'val_best' if val_best else 'latest',
+            "avg_ckpt": dst_model,
             "val_loss_mean": avg_val_score,
             "ckpts": path_list,
             "epochs": selected_epochs.tolist(),
@@ -116,9 +92,40 @@ def average_checkpoints(dst_model="",
         f.write(data + "\n")

+def define_argparse():
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='averaged model')
+    parser.add_argument('--ckpt_dir', required=True, help='ckpt model dir for average')
+    parser.add_argument('--val_best', action="store_true", help='averaged model')
+    parser.add_argument('--num', default=5, type=int, help='nums for averaged model')
+    parser.add_argument('--min_epoch', default=0, type=int,
+                        help='min epoch used for averaging model')
+    parser.add_argument('--max_epoch', default=65536,  # Big enough
+                        type=int, help='max epoch used for averaging model')
+    args = parser.parse_args()
+    print(args)
+    return args
+
 def main():
     args = define_argparse()
-    average_checkpoints(args)
+    average_checkpoints(dst_model=args.dst_model,
+                        ckpt_dir=args.ckpt_dir,
+                        val_best=args.val_best,
+                        num=args.num,
+                        min_epoch=args.min_epoch,
+                        max_epoch=args.max_epoch)

 if __name__ == '__main__':
 ...
```
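With the argparse handling moved out of `average_checkpoints`, the averaging logic is now callable directly with explicit keyword arguments. A usage sketch; only the keyword names come from the diff, the paths are placeholders:

```python
from paddlespeech.dataset.s2t.avg_model import average_checkpoints

# Average the 5 best checkpoints (by validation loss) under ckpt_dir;
# the directory and output name below are placeholders, not from the commit.
average_checkpoints(
    dst_model="exp/chunk_roformer/checkpoints/avg_5.pdparams",
    ckpt_dir="exp/chunk_roformer/checkpoints",
    val_best=True,
    num=5,
    min_epoch=0,
    max_epoch=65536)
```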
paddlespeech/s2t/models/u2/u2.py (+2 −1)

```diff
@@ -145,7 +145,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
                                     text_lengths)
         ctc_time = time.time() - start
         #logger.debug(f"ctc time: {ctc_time}")
-
         if loss_ctc is None:
             loss = loss_att
         elif loss_att is None:
@@ -916,6 +915,8 @@ class U2Model(U2DecodeModel):
         decoder_type = configs.get('decoder', 'transformer')
         logger.debug(f"U2 Decoder type: {decoder_type}")
         if decoder_type == 'transformer':
+            configs['model_conf'].pop('reverse_weight', None)
+            configs['decoder_conf'].pop('r_num_blocks', None)
             decoder = TransformerDecoder(vocab_size,
                                          encoder.output_size(),
                                          **configs['decoder_conf'])
 ...
```
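The two new `pop` calls are load-bearing: `decoder_conf` is splatted into the decoder constructor, so leaving the bitransformer-only key in place would raise a `TypeError` when `decoder: transformer` is configured. A minimal sketch of the failure mode; the stub class is hypothetical, not the real `TransformerDecoder`:

```python
def build_decoder(decoder_conf: dict):
    class TransformerDecoderStub:  # hypothetical stand-in
        def __init__(self, attention_heads: int, num_blocks: int):
            pass
    return TransformerDecoderStub(**decoder_conf)

conf = {'attention_heads': 4, 'num_blocks': 6, 'r_num_blocks': 0}
try:
    build_decoder(conf)              # TypeError: unexpected keyword 'r_num_blocks'
except TypeError:
    conf.pop('r_num_blocks', None)   # what the commit does before the **-splat
    build_decoder(conf)              # now constructs cleanly
```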
paddlespeech/s2t/modules/attention.py (+142 −1)

```diff
@@ -15,6 +15,7 @@
 # Modified from wenet(https://github.com/wenet-e2e/wenet)
 """Multi-Head Attention layer definition."""
 import math
+from typing import List
 from typing import Tuple

 import paddle
@@ -26,7 +27,10 @@ from paddlespeech.s2t.utils.log import Log
 logger = Log(__name__).getlog()

-__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
+__all__ = [
+    "MultiHeadedAttention", "RelPositionMultiHeadedAttention",
+    "RoPERelPositionMultiHeadedAttention"
+]

 # Relative Positional Encodings
 # https://www.jianshu.com/p/c0608efcc26f
@@ -165,6 +169,7 @@ class MultiHeadedAttention(nn.Layer):
                 and `head * d_k == size`
         """
+        # (B,T,D) -> (B,T,H,D/H)
         q, k, v = self.forward_qkv(query, key, value)

         # when export onnx model, for 1st chunk, we feed
@@ -373,3 +378,139 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                                self.d_k)  # (batch, head, time1, time2)

         return self.forward_attention(v, scores, mask), new_cache
+
+
+class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with RoPE relative position encoding."""
+
+    def __init__(self,
+                 n_head,
+                 n_feat,
+                 dropout_rate,
+                 adaptive_scale=False,
+                 init_weights=False):
+        """Construct a RoPERelPositionMultiHeadedAttention object.
+        Paper: https://arxiv.org/abs/2104.09864
+        Args:
+            n_head (int): The number of heads.
+            n_feat (int): The number of features.
+            dropout_rate (float): Dropout rate.
+        """
+        super().__init__(n_head, n_feat, dropout_rate)
+
+    def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
+        """Realign the tensor (a batched version of expand_dims).
+        axes: original dim i is aligned to dim axes[i] of the new tensor;
+        ndim: number of dimensions of the new tensor.
+        """
+        assert len(axes) == tensor.dim()
+        assert ndim or min(axes) >= 0
+        ndim = ndim or max(axes) + 1
+        # a[0, None, 1] = a[0, np.newaxis, 1]
+        indices = [None] * ndim
+        for i in axes:
+            # slice nothing, a[0, slice(None), 1] = a[0, :, 1]
+            indices[i] = slice(None)
+        return tensor[indices]
+
+    def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
+        """Apply RoPE to the input tensors.
+        Here sinusoidal.shape=[B, T, D], tensors is a list of tensors, and
+        tensor.shape=[B, T, ..., D], or (B,H,T,D/H).
+        """
+        assert len(tensors) > 0, 'at least one input tensor'
+        assert all(
+            [tensor.shape == tensors[0].shape
+             for tensor in tensors[1:]]), 'all tensors must have the same shape'
+
+        # (B,H,T,D)
+        ndim = tensors[0].dim()
+        _, H, T, D = tensors[0].shape
+
+        # sinusoidal shape same with tensors[0]
+        # [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
+        # sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+        sinusoidal = sinusoidal.reshape((1, T, H, D)).transpose([0, 2, 1, 3])
+
+        # http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
+        # like np.repeat, x (s1, s2, s3), axis 1 -> (s1, s2*rep, s3)
+        # [b,T, ..., d/2] -> [b,T, ..., d]
+        cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
+        sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
+        outputs = []
+        for tensor in tensors:
+            # x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
+            tensor2 = paddle.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
+            tensor2 = paddle.reshape(tensor2, paddle.shape(tensor))
+            # Eq. 34: out = x * cos_pos + x2 * sin_pos
+            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
+        return outputs[0] if len(outputs) == 1 else outputs
+
+    def forward(self,
+                query: paddle.Tensor,
+                key: paddle.Tensor,
+                value: paddle.Tensor,
+                mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+                pos_emb: paddle.Tensor = paddle.empty([0]),
+                cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+        Ref: https://github.com/facebookresearch/llama/blob/main/llama/model.py
+        Args:
+            query (paddle.Tensor): Query tensor (#batch, time1, size).
+            key (paddle.Tensor): Key tensor (#batch, time2, size).
+            value (paddle.Tensor): Value tensor (#batch, time2, size).
+            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        Returns:
+            paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
+
+        # f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
+        # q_t always is chunk_size
+        q_t = q.shape[2]
+        q = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], q)
+        # k will increase when in streaming decoding.
+        k = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], k)
+
+        # When exporting an onnx model, for the 1st chunk we feed
+        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #   In all modes, `if cache.size(0) > 0` will always be `True`
+        #   and we will always do splitting and
+        #   concatenation (this simplifies onnx export). Note that
+        #   it's OK to concat & split zero-shaped tensors (see code below).
+        #   When exporting a jit model, for the 1st chunk we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branches.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if cache.shape[0] > 0:
+            # last dim `d_k * 2` for (key, val)
+            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)
+
+        # dot(q, k)
+        scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
+
+        return self.forward_attention(v, scores, mask), new_cache
```
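What the interleaved sin/cos arithmetic buys is that attention scores depend only on relative offsets: rotating q at position m and k at position n makes their dot product a function of n − m alone. A standalone NumPy check of that property; this is illustrative and independent of the Paddle code above:

```python
import numpy as np

def rope(x: np.ndarray, pos: float, base: float = 10000.0) -> np.ndarray:
    """Rotate feature pairs of x (shape [d]) by pos * theta_i, mirroring
    `tensor * cos_pos + tensor2 * sin_pos` in the class above."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)   # (d/2,) pair frequencies
    cos = np.repeat(np.cos(pos * theta), 2)     # [c0, c0, c1, c1, ...]
    sin = np.repeat(np.sin(pos * theta), 2)
    # x2 = [-x1, x0, -x3, x2, ...], the rotated partner of each pair
    x2 = np.stack([-x[1::2], x[0::2]], axis=-1).reshape(x.shape)
    return x * cos + x2 * sin

rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)
# Shifting both positions by the same amount leaves the score unchanged:
# (3, 10) and (8, 15) have the same relative offset of 7.
assert np.isclose(rope(q, 3) @ rope(k, 10), rope(q, 8) @ rope(k, 15))
```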
paddlespeech/s2t/modules/embedding.py (+98 −3)

```diff
@@ -85,18 +85,21 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
             reverse (bool, optional): Not used. Defaults to False.
         """
         nn.Layer.__init__(self)
-        self.d_model = d_model
+        self.d_model = paddle.to_tensor(d_model)
         self.max_len = max_len
         self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
         self.dropout = nn.Dropout(p=dropout_rate)
+        self.base = paddle.to_tensor(10000.0)
         self.pe = paddle.zeros([1, self.max_len, self.d_model])  #[B=1,T,D]
         position = paddle.arange(
             0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
+        # base^{-2(i-1)/d)}, i \in (1,2...,d/2)
         div_term = paddle.exp(
             paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-            -(math.log(10000.0) / self.d_model))
+            -(paddle.log(self.base) / self.d_model))
+        # [B,T,D]
         self.pe[:, :, 0::2] = paddle.sin(position * div_term)
         self.pe[:, :, 1::2] = paddle.cos(position * div_term)
@@ -161,6 +164,98 @@ class RelPositionalEncoding(PositionalEncoding):
         assert offset + x.shape[1] < self.max_len, \
             "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
                 offset, x.shape[1], self.max_len)
         x = x * self.xscale
         pos_emb = self.pe[:, offset:offset + x.shape[1]]
         return self.dropout(x), self.dropout(pos_emb)
+
+
+# RotaryRelPositionalEncoding is the same as RelPositionalEncoding
+
+
+class ScaledRotaryRelPositionalEncoding(RelPositionalEncoding):
+    """Scaled rotary relative positional encoding module.
+    POSITION INTERPOLATION: https://arxiv.org/pdf/2306.15595v2.pdf
+    """
+
+    def __init__(self,
+                 d_model: int,
+                 dropout_rate: float,
+                 max_len: int=5000,
+                 scale=1):
+        """
+        Args:
+            d_model (int): Embedding dimension.
+            dropout_rate (float): Dropout rate.
+            max_len (int, optional): Maximum input length. Defaults to 5000.
+            scale (int): Interpolate max input length to `scale * max_len` positions.
+        """
+        super().__init__(d_model, dropout_rate, max_len, reverse=True)
+        self.pscale = paddle.to_tensor(scale)
+        self.max_len = max_len * scale
+
+    def sinusoidal_embeddings(self,
+                              pos: paddle.Tensor,
+                              dim: paddle.Tensor,
+                              base=10000) -> paddle.Tensor:
+        """Compute dim-dimensional sinusoidal encodings for positions pos."""
+        assert dim % 2 == 0
+        # (d/2,)
+        indices = paddle.arange(0, dim // 2, dtype=pos.dtype)
+        indices = paddle.pow(paddle.cast(base, pos.dtype), -2 * indices / dim)
+        # pos (1, T), indices (d/2,) -> (1, T, d/2)
+        embeddings = paddle.einsum('...,d->...d', pos, indices)
+        # (1, T, d/2, 2)
+        embeddings = paddle.stack(
+            [paddle.sin(embeddings), paddle.cos(embeddings)], axis=-1)
+        # (1, T, d)
+        embeddings = paddle.flatten(embeddings, start_axis=-2, stop_axis=-1)
+        return embeddings
+
+    def forward(self, x: paddle.Tensor,
+                offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute positional encoding.
+        Args:
+            x (paddle.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            paddle.Tensor: Encoded tensor (batch, time, `*`).
+            paddle.Tensor: Positional embedding tensor (1, time, `*`).
+        """
+        x = x * self.xscale
+
+        B, T, D = x.shape
+        assert D == self.d_model
+
+        # position interpolation
+        start = 0
+        end = T * self.pscale
+        assert end <= self.max_len
+        position = paddle.arange(start, end, dtype=x.dtype).unsqueeze(0)
+        position *= 1.0 / self.pscale
+
+        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
+        pos_emb = pe[:, offset:offset + x.shape[1]]
+        return self.dropout(x), self.dropout(pos_emb)
+
+    def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
+        """For getting encoding in a streaming fashion.
+        Attention!!!!!
+        We apply dropout only once at the whole utterance level in a
+        non-streaming way, but will call this function several times with
+        increasing input size in a streaming scenario, so the dropout will
+        be applied several times.
+        Args:
+            offset (int): start offset
+            size (int): required size of position encoding
+        Returns:
+            paddle.Tensor: Corresponding position encoding, #[1, T, D].
+        """
+        # position interpolation
+        start = offset
+        end = (offset + size) * self.pscale
+        assert end <= self.max_len
+        position = paddle.arange(
+            start, end, dtype=paddle.get_default_dtype()).unsqueeze(0)
+        position *= 1.0 / self.pscale
+
+        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
+        return self.dropout(pe)
```
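The position-interpolation trick in `ScaledRotaryRelPositionalEncoding` follows the paper linked in its docstring: instead of extrapolating to unseen positions, tokens are placed on fractional positions `0, 1/scale, 2/scale, ...`, so sequences up to `scale * max_len` reuse the angle range trained for `max_len`. A tiny standalone illustration of that index arithmetic; the values are assumed, not from the configs above:

```python
import numpy as np

def interpolated_positions(size: int, scale: int) -> np.ndarray:
    """Fractional position ids, as built by `position *= 1.0 / self.pscale`:
    `size` tokens are squeezed into the first size/scale trained positions."""
    return np.arange(size, dtype=np.float64) / scale

# With scale=2, 8000 tokens only reach position 3999.5, which stays inside
# an angle range trained with max_len=5000.
pos = interpolated_positions(8000, 2)
print(pos[:4], pos[-1])  # [0.  0.5 1.  1.5] 3999.5
```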
paddlespeech/s2t/modules/encoder.py (+35 −10)

```diff
@@ -28,6 +28,7 @@ from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.attention import MultiHeadedAttention
 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
+from paddlespeech.s2t.modules.attention import RoPERelPositionMultiHeadedAttention
 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
 from paddlespeech.s2t.modules.embedding import PositionalEncoding
@@ -115,6 +116,8 @@ class BaseEncoder(nn.Layer):
             pos_enc_class = PositionalEncoding
         elif pos_enc_layer_type == "rel_pos":
             pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "rope_pos":
+            pos_enc_class = RelPositionalEncoding
         elif pos_enc_layer_type == "no_pos":
             pos_enc_class = NoPositionalEncoding
         else:
@@ -230,14 +233,14 @@ class BaseEncoder(nn.Layer):
             xs = self.global_cmvn(xs)
         # before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
-        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
+        xs, _, _ = self.embed(xs, tmp_masks, offset=offset)
         # after embed, xs=(B=1, chunk_size, hidden-dim)

         elayers, _, cache_t1, _ = att_cache.shape
         chunk_size = xs.shape[1]
         attention_key_size = cache_t1 + chunk_size

-        # only used when using `RelPositionMultiHeadedAttention`
+        # only used when using `RelPositionMultiHeadedAttention` and `RoPERelPositionMultiHeadedAttention`
         pos_emb = self.embed.position_encoding(
             offset=offset - cache_t1, size=attention_key_size)
@@ -474,21 +477,35 @@ class ConformerEncoder(BaseEncoder):
         activation = get_activation(activation_type)

         # self-attention module definition
-        encoder_selfattn_layer = RelPositionMultiHeadedAttention
-        encoder_selfattn_layer_args = (attention_heads, output_size,
-                                       attention_dropout_rate)
+        encoder_dim = output_size
+        if pos_enc_layer_type == "abs_pos":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rel_pos":
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")

         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
-        positionwise_layer_args = (output_size, linear_units, dropout_rate,
+        positionwise_layer_args = (encoder_dim, linear_units, dropout_rate,
                                    activation)
         # convolution module definition
         convolution_layer = ConvolutionModule
-        convolution_layer_args = (output_size, cnn_module_kernel, activation,
+        convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
                                   cnn_module_norm, causal)

         self.encoders = nn.LayerList([
             ConformerEncoderLayer(
-                size=output_size,
+                size=encoder_dim,
                 self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
                 feed_forward=positionwise_layer(*positionwise_layer_args),
                 feed_forward_macaron=positionwise_layer(
@@ -580,15 +597,23 @@ class SqueezeformerEncoder(nn.Layer):
         activation = get_activation(activation_type)

         # self-attention module definition
-        if pos_enc_layer_type != "rel_pos":
+        if pos_enc_layer_type == "abs_pos":
             encoder_selfattn_layer = MultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, output_size,
                                            attention_dropout_rate)
-        else:
+        elif pos_enc_layer_type == "rel_pos":
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, encoder_dim,
                                            attention_dropout_rate,
                                            adaptive_scale, init_weights)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate,
+                                           adaptive_scale, init_weights)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")

         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
 ...
```
paddlespeech/s2t/modules/encoder_layer.py (+3 −3)

```diff
@@ -48,7 +48,7 @@ class TransformerEncoderLayer(nn.Layer):
     Args:
         size (int): Input dimension.
         self_attn (nn.Layer): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
             instance can be used as the argument.
         feed_forward (nn.Layer): Feed-forward module instance.
             `PositionwiseFeedForward`, instance can be used as the argument.
@@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer):
     Args:
         size (int): Input dimension.
         self_attn (nn.Layer): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
            instance can be used as the argument.
         feed_forward (nn.Layer): Feed-forward module instance.
             `PositionwiseFeedForward` instance can be used as the argument.
@@ -298,7 +298,7 @@ class SqueezeformerEncoderLayer(nn.Layer):
     Args:
         size (int): Input dimension.
         self_attn (paddle.nn.Layer): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
             instance can be used as the argument.
         feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
             `PositionwiseFeedForward` instance can be used as the argument.
 ...
```