Commit 03e9ea9e: add roformer
Repository: PaddlePaddle / DeepSpeech
Author: Hui Zhang
Authored: July 12, 2023
Parent: 94987f26
Showing 6 changed files with 367 additions and 12 deletions (+367 / -12):

  examples/aishell/asr1/conf/chunk_roformer.yaml             +98   -0
  examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml   +98   -0
  paddlespeech/s2t/modules/attention.py                      +133  -1
  paddlespeech/s2t/modules/embedding.py                      +5    -2
  paddlespeech/s2t/modules/encoder.py                        +30   -6
  paddlespeech/s2t/modules/encoder_layer.py                  +3    -3
examples/aishell/asr1/conf/chunk_roformer.yaml (new file, mode 100644)

############################################
# Network Architecture                     #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: True
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rope_pos'            # abs_pos, rel_pos, rope_pos
    selfattention_layer_type: 'rel_selfattn'  # unused
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm'  # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false

# decoder related
decoder: transformer  # transformer, bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    r_num_blocks: 3     # only for bitransformer
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    reverse_weight: 0.3 # only for bitransformer
    length_normalized_loss: false
    init_type: 'kaiming_uniform'  # !Warning: convergence needs to be verified

###########################################
# Data                                    #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test

###########################################
# Dataloader                              #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0     # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512   # if input length  > maxlen-in, batchsize is automatically reduced
maxlen_out: 150  # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0   # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1

###########################################
# Training                                #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
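For reference, here is a minimal usage sketch (not part of the commit): it loads the new config with PyYAML and prints the keys that select the RoPE code path added in paddlespeech/s2t/modules/encoder.py below. Only the file path and key names come from this commit; the loading code itself is illustrative.

import yaml  # PyYAML, assumed to be installed

# Load the new AISHELL-1 streaming RoFormer config.
with open("examples/aishell/asr1/conf/chunk_roformer.yaml") as f:
    conf = yaml.safe_load(f)

enc_conf = conf["encoder_conf"]
print("encoder:", conf["encoder"])                            # conformer
print("pos_enc_layer_type:", enc_conf["pos_enc_layer_type"])  # rope_pos
print("decoder:", conf["decoder"])                            # transformer

# 'rope_pos' is the value that routes BaseEncoder to RelPositionalEncoding (for
# the sinusoidal table) and ConformerEncoder to RoPERelPositionMultiHeadedAttention.
assert enc_conf["pos_enc_layer_type"] == "rope_pos"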
examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml (new file, mode 100644)

############################################
# Network Architecture                     #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: True
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rope_pos'            # abs_pos, rel_pos, rope_pos
    selfattention_layer_type: 'rel_selfattn'  # unused
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm'  # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false

# decoder related
decoder: bitransformer  # transformer, bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 3
    r_num_blocks: 3     # only for bitransformer
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    reverse_weight: 0.3 # only for bitransformer
    length_normalized_loss: false
    init_type: 'kaiming_uniform'  # !Warning: convergence needs to be verified

###########################################
# Data                                    #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test

###########################################
# Dataloader                              #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0     # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512   # if input length  > maxlen-in, batchsize is automatically reduced
maxlen_out: 150  # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0   # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1

###########################################
# Training                                #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
paddlespeech/s2t/modules/attention.py
@@ -26,7 +26,10 @@ from paddlespeech.s2t.utils.log import Log
 logger = Log(__name__).getlog()
 
-__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
+__all__ = [
+    "MultiHeadedAttention", "RelPositionMultiHeadedAttention",
+    "RoPERelPositionMultiHeadedAttention"
+]
 
 # Relative Positional Encodings
 # https://www.jianshu.com/p/c0608efcc26f
@@ -165,6 +168,7 @@ class MultiHeadedAttention(nn.Layer):
                 and `head * d_k == size`
         """
+        # (B,T,D) -> (B,T,H,D/H)
         q, k, v = self.forward_qkv(query, key, value)
 
         # when export onnx model, for 1st chunk, we feed
@@ -373,3 +377,131 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             self.d_k)  # (batch, head, time1, time2)
 
         return self.forward_attention(v, scores, mask), new_cache
+
+
+class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with RoPE relative position encoding."""
+
+    def __init__(self,
+                 n_head,
+                 n_feat,
+                 dropout_rate,
+                 adaptive_scale=False,
+                 init_weights=False):
+        """Construct a RoPERelPositionMultiHeadedAttention object.
+        Paper: RoFormer, https://arxiv.org/abs/2104.09864
+        Args:
+            n_head (int): The number of heads.
+            n_feat (int): The number of features.
+            dropout_rate (float): Dropout rate.
+        """
+        super().__init__(n_head, n_feat, dropout_rate)
+
+    def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
+        """Re-align a tensor (a batched version of expand_dims).
+        axes: original dim i of `tensor` is mapped to dim axes[i] of the new tensor;
+        ndim: number of dimensions of the new tensor.
+        """
+        assert len(axes) == tensor.dim()
+        assert ndim or min(axes) >= 0
+        ndim = ndim or max(axes) + 1
+
+        # a[0, None, 1] = a[0, np.newaxis, 1]
+        indices = [None] * ndim
+        for i in axes:
+            # slice nothing, a[0, slice(None), 1] = a[0, :, 1]
+            indices[i] = slice(None)
+
+        return tensor[indices]
+
+    def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
+        """Apply RoPE to each tensor in `tensors`.
+        sinusoidal.shape = [B, T, D]; each tensor has
+        shape [B, T, ..., D], e.g. (B, T, H, D/H).
+        """
+        assert len(tensors) > 0, 'at least one input tensor'
+        assert all(
+            [tensor.shape == tensors[0].shape
+             for tensor in tensors[1:]]), 'all tensors must have the same shape'
+
+        ndim = tensors[0].dim()
+
+        # align sinusoidal with the rank of tensors[0]
+        # [B,T,D] -> [B,T,1,D]
+        sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+
+        # http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
+        # like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
+        # [B,T, ..., d/2] -> [B,T, ..., d]
+        cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
+        sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
+        outputs = []
+        for tensor in tensors:
+            # x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
+            tensor2 = paddle.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
+            tensor2 = paddle.reshape(tensor2, paddle.shape(tensor))
+            # Eq. (34): out = x * cos_pos + x2 * sin_pos
+            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
+        return outputs[0] if len(outputs) == 1 else outputs
+
+    def forward(self,
+                query: paddle.Tensor,
+                key: paddle.Tensor,
+                value: paddle.Tensor,
+                mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+                pos_emb: paddle.Tensor = paddle.empty([0]),
+                cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+        Args:
+            query (paddle.Tensor): Query tensor (#batch, time1, size).
+            key (paddle.Tensor): Key tensor (#batch, time2, size).
+            value (paddle.Tensor): Value tensor (#batch, time2, size).
+            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        Returns:
+            paddle.Tensor: Output tensor (#batch, time1, d_model).
+            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
+
+        # when export onnx model, for 1st chunk, we feed
+        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #   In all modes, `if cache.size(0) > 0` will always be `True`
+        #   and we will always do splitting and
+        #   concatenation (this will simplify onnx export). Note that
+        #   it's OK to concat & split zero-shaped tensors (see code below).
+        #   when export jit model, for 1st chunk, we always feed
+        #   cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)         # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])   # True
+        if cache.shape[0] > 0:
+            # last dim `d_k * 2` for (key, val)
+            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        # We do cache slicing in encoder.forward_chunk, since it's
+        # non-trivial to calculate `next_cache_start` here.
+        new_cache = paddle.concat((k, v), axis=-1)
+
+        # f_{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, where m is the position index
+        q, k = self.apply_rotary_position_embeddings(pos_emb, q, k)
+
+        # dot(q, k)
+        scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
+
+        return self.forward_attention(v, scores, mask), new_cache
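To make the rotation above concrete, here is a small NumPy sketch (not part of the commit) that mirrors apply_rotary_position_embeddings: it builds the same sin/cos table as PositionalEncoding (sin in the even slots, cos in the odd slots, base 10000), applies the interleaved trick (cos_pos/sin_pos repeated per pair, x2 = [-x2, x1, -x4, x3, ...]) and checks the result against an explicit per-pair 2x2 rotation. NumPy is used only to keep the check dependency-light; all variable names are illustrative.

import numpy as np

B, T, H, D = 1, 4, 2, 8          # batch, time, heads, per-head dim (D even)
base = 10000.0

# Same table as PositionalEncoding: sin at even indices, cos at odd indices.
inv_freq = base ** (-np.arange(0, D, 2) / D)           # [D/2]
angles = np.arange(T)[:, None] * inv_freq[None, :]     # [T, D/2]
sinusoidal = np.zeros((1, T, D))
sinusoidal[0, :, 0::2] = np.sin(angles)
sinusoidal[0, :, 1::2] = np.cos(angles)

x = np.random.randn(B, T, H, D)

# Interleaved trick: repeat cos/sin per pair, build x2 = (-x2, x1, -x4, x3, ...).
sin_al = sinusoidal[:, :, None, :]                     # align to [B, T, 1, D]
cos_pos = np.repeat(sin_al[..., 1::2], 2, axis=-1)
sin_pos = np.repeat(sin_al[..., 0::2], 2, axis=-1)
x2 = np.stack([-x[..., 1::2], x[..., 0::2]], axis=-1).reshape(x.shape)
rotated = x * cos_pos + x2 * sin_pos                   # Eq. (34)

# Explicit 2x2 rotation on each (x_{2i}, x_{2i+1}) pair for comparison.
cos_a, sin_a = np.cos(angles), np.sin(angles)          # [T, D/2]
even, odd = x[..., 0::2], x[..., 1::2]
ref = np.empty_like(x)
ref[..., 0::2] = even * cos_a[None, :, None, :] - odd * sin_a[None, :, None, :]
ref[..., 1::2] = even * sin_a[None, :, None, :] + odd * cos_a[None, :, None, :]

assert np.allclose(rotated, ref)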
paddlespeech/s2t/modules/embedding.py
@@ -89,14 +89,17 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
         self.max_len = max_len
         self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
         self.dropout = nn.Dropout(p=dropout_rate)
+        self.base = 10000.0
         self.pe = paddle.zeros([1, self.max_len, self.d_model])  #[B=1,T,D]
 
         position = paddle.arange(
             0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
+        # base^{-2(i-1)/d}, i \in (1, 2, ..., d/2)
         div_term = paddle.exp(
             paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-            -(math.log(10000.0) / self.d_model))
+            -(math.log(self.base) / self.d_model))
 
+        # [B,T,D]
         self.pe[:, :, 0::2] = paddle.sin(position * div_term)
         self.pe[:, :, 1::2] = paddle.cos(position * div_term)
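For clarity, the table built above is the standard sinusoidal encoding, written here with the new self.base attribute (default 10000.0) and a 0-based index i to match paddle.arange(0, d_model, 2):

\mathrm{PE}(t, 2i) = \sin\!\left(t \cdot \mathrm{base}^{-2i/d_{\mathrm{model}}}\right), \qquad
\mathrm{PE}(t, 2i+1) = \cos\!\left(t \cdot \mathrm{base}^{-2i/d_{\mathrm{model}}}\right), \qquad
i = 0, 1, \ldots, d_{\mathrm{model}}/2 - 1.

This same table (sin in the even slots, cos in the odd slots) is what RoPERelPositionMultiHeadedAttention consumes as pos_emb.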
paddlespeech/s2t/modules/encoder.py
@@ -28,6 +28,7 @@ from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.attention import MultiHeadedAttention
 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
+from paddlespeech.s2t.modules.attention import RoPERelPositionMultiHeadedAttention
 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
 from paddlespeech.s2t.modules.embedding import PositionalEncoding
@@ -115,6 +116,8 @@ class BaseEncoder(nn.Layer):
             pos_enc_class = PositionalEncoding
         elif pos_enc_layer_type == "rel_pos":
             pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "rope_pos":
+            pos_enc_class = RelPositionalEncoding
         elif pos_enc_layer_type == "no_pos":
             pos_enc_class = NoPositionalEncoding
         else:
@@ -237,7 +240,7 @@ class BaseEncoder(nn.Layer):
         chunk_size = xs.shape[1]
         attention_key_size = cache_t1 + chunk_size
-        # only used when using `RelPositionMultiHeadedAttention`
+        # only used when using `RelPositionMultiHeadedAttention` and `RoPERelPositionMultiHeadedAttention`
         pos_emb = self.embed.position_encoding(
             offset=offset - cache_t1, size=attention_key_size)
@@ -474,9 +477,22 @@ class ConformerEncoder(BaseEncoder):
         activation = get_activation(activation_type)
 
         # self-attention module definition
-        encoder_selfattn_layer = RelPositionMultiHeadedAttention
-        encoder_selfattn_layer_args = (attention_heads, output_size,
-                                       attention_dropout_rate)
+        if pos_enc_layer_type == "abs_pos":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, output_size,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rel_pos":
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
 
         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
         positionwise_layer_args = (output_size, linear_units, dropout_rate,
@@ -580,15 +596,23 @@ class SqueezeformerEncoder(nn.Layer):
         activation = get_activation(activation_type)
 
         # self-attention module definition
-        if pos_enc_layer_type != "rel_pos":
+        if pos_enc_layer_type == "abs_pos":
             encoder_selfattn_layer = MultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, output_size,
                                            attention_dropout_rate)
-        else:
+        elif pos_enc_layer_type == "rel_pos":
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, encoder_dim,
                                            attention_dropout_rate,
                                            adaptive_scale, init_weights)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate,
+                                           adaptive_scale, init_weights)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
 
         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
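Condensed for clarity, a small sketch (not from the commit) of the selection logic the two hunks above add to ConformerEncoder and SqueezeformerEncoder: the pos_enc_layer_type string now picks the self-attention class, and unsupported values raise instead of silently falling back. The helper name and its string return value are illustrative only.

def select_selfattn_layer(pos_enc_layer_type: str) -> str:
    # Mirrors the if/elif chain in the diff above; returns the class name only.
    table = {
        "abs_pos": "MultiHeadedAttention",
        "rel_pos": "RelPositionMultiHeadedAttention",
        "rope_pos": "RoPERelPositionMultiHeadedAttention",
    }
    if pos_enc_layer_type not in table:
        raise ValueError(f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
    return table[pos_enc_layer_type]

print(select_selfattn_layer("rope_pos"))  # RoPERelPositionMultiHeadedAttention

Note that for the positional-encoding class itself, BaseEncoder maps both "rel_pos" and "rope_pos" to RelPositionalEncoding, so the RoPE attention receives the same sin/cos table as the relative-position path.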
paddlespeech/s2t/modules/encoder_layer.py
@@ -48,7 +48,7 @@ class TransformerEncoderLayer(nn.Layer):
     Args:
         size (int): Input dimension.
         self_attn (nn.Layer): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
             instance can be used as the argument.
         feed_forward (nn.Layer): Feed-forward module instance.
             `PositionwiseFeedForward`, instance can be used as the argument.
@@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer):
     Args:
         size (int): Input dimension.
         self_attn (nn.Layer): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
             instance can be used as the argument.
         feed_forward (nn.Layer): Feed-forward module instance.
             `PositionwiseFeedForward` instance can be used as the argument.
@@ -298,7 +298,7 @@ class SqueezeformerEncoderLayer(nn.Layer):
     Args:
         size (int): Input dimension.
         self_attn (paddle.nn.Layer): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
             instance can be used as the argument.
         feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
             `PositionwiseFeedForward` instance can be used as the argument.