Commit d53e1163
Authored on Mar 22, 2022 by huangyuxin

update the code, test=asr

Parent commit: ab16d8ce

Showing 16 changed files with 196 additions and 324 deletions (+196 -324)
paddlespeech/s2t/__init__.py                              +0    -6
paddlespeech/s2t/exps/u2/model.py                         +2    -2
paddlespeech/s2t/models/u2/u2.py                          +4    -6
paddlespeech/s2t/modules/activation.py                    +7    -6
paddlespeech/s2t/modules/align.py                         +74   -0
paddlespeech/s2t/modules/attention.py                     +6    -5
paddlespeech/s2t/modules/conformer_convolution.py         +9    -16
paddlespeech/s2t/modules/ctc.py                           +2    -1
paddlespeech/s2t/modules/decoder.py                       +6    -13
paddlespeech/s2t/modules/decoder_layer.py                 +7    -23
paddlespeech/s2t/modules/encoder.py                       +3    -7
paddlespeech/s2t/modules/encoder_layer.py                 +13   -39
paddlespeech/s2t/modules/initializer.py                   +44   -141
paddlespeech/s2t/modules/nets_utils.py                    +0    -44
paddlespeech/s2t/modules/positionwise_feed_forward.py     +3    -2
paddlespeech/s2t/modules/subsampling.py                   +16   -13
paddlespeech/s2t/__init__.py  (+0 -6)

@@ -21,7 +21,6 @@ from paddle import nn
 from paddle.fluid import core
 from paddle.nn import functional as F
-from paddlespeech.s2t.modules import initializer
 from paddlespeech.s2t.utils.log import Log

 #TODO(Hui Zhang): remove fluid import

@@ -506,8 +505,3 @@ if not hasattr(paddle.nn, 'LayerDict'):
     logger.debug(
         "register user LayerDict to paddle.nn, remove this when fixed!")
     setattr(paddle.nn, 'LayerDict', LayerDict)
-
-"""
-hack KaiminigUniform: change limit from np.sqrt(6.0 / float(fan_in)) to np.sqrt(1.0 / float(fan_in))
-"""
-paddle.nn.initializer.KaimingUniform = initializer.KaimingUniform
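Note: the deleted block above monkey-patched paddle.nn.initializer.KaimingUniform globally so every layer was initialized with a uniform limit of sqrt(1.0 / fan_in) instead of Paddle's stock sqrt(6.0 / fan_in). A minimal sketch of the two limits (plain NumPy; the fan_in value is an arbitrary example, not from the commit):

    import numpy as np

    def kaiming_uniform_limits(fan_in):
        """Stock Paddle KaimingUniform limit vs. the patched limit removed above."""
        stock = np.sqrt(6.0 / float(fan_in))    # paddle.nn.initializer.KaimingUniform
        patched = np.sqrt(1.0 / float(fan_in))  # paddlespeech.s2t.modules.initializer.KaimingUniform
        return stock, patched

    print(kaiming_uniform_limits(256))  # (~0.1531, 0.0625)

After this commit the narrower limit is applied per layer through the wrappers in paddlespeech/s2t/modules/align.py rather than by patching Paddle globally.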
paddlespeech/s2t/exps/u2/model.py  (+2 -2)

@@ -239,7 +239,7 @@ class U2Trainer(Trainer):
             n_iter_processes=config.num_workers,
             subsampling_factor=1,
             num_encs=1,
-            dist_sampler=False,
+            dist_sampler=True,
             shortest_first=False)

         self.valid_loader = BatchDataLoader(

@@ -260,7 +260,7 @@ class U2Trainer(Trainer):
             n_iter_processes=config.num_workers,
             subsampling_factor=1,
             num_encs=1,
-            dist_sampler=False,
+            dist_sampler=True,
             shortest_first=False)
         logger.info("Setup train/valid Dataloader!")
     else:
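Note: both hunks only flip dist_sampler from False to True, so the train and valid BatchDataLoader instances shard batches across distributed trainers instead of every rank iterating the full dataset. A rough sketch of the idea using stock Paddle APIs rather than PaddleSpeech's BatchDataLoader (the toy dataset and batch size are placeholders, not part of the commit):

    import paddle
    from paddle.io import DataLoader, Dataset, DistributedBatchSampler

    class ToyDataset(Dataset):
        """Placeholder dataset standing in for the ASR batch dataset."""
        def __len__(self):
            return 1000
        def __getitem__(self, idx):
            return paddle.to_tensor([float(idx)])

    dataset = ToyDataset()
    # With a distributed batch sampler each rank iterates a disjoint shard of
    # the batches, which is what dist_sampler=True turns on inside BatchDataLoader.
    sampler = DistributedBatchSampler(dataset, batch_size=16, shuffle=True, drop_last=False)
    loader = DataLoader(dataset, batch_sampler=sampler)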
paddlespeech/s2t/models/u2/u2.py  (+4 -6)

@@ -41,7 +41,6 @@ from paddlespeech.s2t.modules.mask import make_pad_mask
 from paddlespeech.s2t.modules.mask import mask_finished_preds
 from paddlespeech.s2t.modules.mask import mask_finished_scores
 from paddlespeech.s2t.modules.mask import subsequent_mask
-from paddlespeech.s2t.modules.nets_utils import initialize
 from paddlespeech.s2t.utils import checkpoint
 from paddlespeech.s2t.utils import layer_tools
 from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank

@@ -51,6 +50,8 @@ from paddlespeech.s2t.utils.tensor_utils import pad_sequence
 from paddlespeech.s2t.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.utils.utility import log_add
 from paddlespeech.s2t.utils.utility import UpdateConfig
+from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
+# from paddlespeech.s2t.modules.initializer import initialize

 __all__ = ["U2Model", "U2InferModel"]

@@ -784,11 +785,8 @@ class U2Model(U2DecodeModel):
     def __init__(self, configs: dict):
         model_conf = configs.get('model_conf', dict())
         init_type = model_conf.get("init_type", None)
-        if init_type is not None:
-            logger.info(f"Use {init_type} initializer as default initializer")
-            initialize(self, init_type)
-        vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
-        nn.initializer.set_global_initializer(None)
+        with DefaultInitializerContext(init_type):
+            vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
         super().__init__(
             vocab_size=vocab_size,
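Note: the constructor now builds its submodules inside DefaultInitializerContext, which (per the initializer.py and align.py changes below) sets align.global_init_type on entry and clears it on exit, so the align wrappers pick up the Kaiming-uniform default during construction. A minimal usage sketch, with a made-up model_conf dict:

    from paddlespeech.s2t.modules.align import Linear
    from paddlespeech.s2t.modules.initializer import DefaultInitializerContext

    model_conf = {"init_type": "kaiming_uniform"}   # hypothetical config
    init_type = model_conf.get("init_type", None)

    with DefaultInitializerContext(init_type):
        # Layers created here see align.global_init_type == "kaiming_uniform",
        # so their weight/bias ParamAttr use the custom KaimingUniform initializer.
        proj = Linear(256, 512)
    # On exit global_init_type is reset to None, so layers built afterwards
    # fall back to Paddle's default initializers.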
paddlespeech/s2t/modules/activation.py  (+7 -6)

@@ -16,7 +16,8 @@ from collections import OrderedDict
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
+from paddlespeech.s2t.modules.align import Conv2D
+from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()

@@ -51,7 +52,7 @@ class LinearGLUBlock(nn.Layer):
         idim (int): input and output dimension
     """
     super().__init__()
-    self.fc = nn.Linear(idim, idim * 2)
+    self.fc = Linear(idim, idim * 2)

 def forward(self, xs):
     return glu(self.fc(xs), dim=-1)

@@ -75,7 +76,7 @@ class ConvGLUBlock(nn.Layer):
     self.conv_residual = None
     if in_ch != out_ch:
         self.conv_residual = nn.utils.weight_norm(
-            nn.Conv2D(
+            Conv2D(
                 in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1)),
             name='weight',
             dim=0)

@@ -86,7 +87,7 @@ class ConvGLUBlock(nn.Layer):
     layers = OrderedDict()
     if bottlececk_dim == 0:
         layers['conv'] = nn.utils.weight_norm(
-            nn.Conv2D(
+            Conv2D(
                 in_channels=in_ch,
                 out_channels=out_ch * 2,
                 kernel_size=(kernel_size, 1)),

@@ -106,7 +107,7 @@ class ConvGLUBlock(nn.Layer):
             dim=0)
         layers['dropout_in'] = nn.Dropout(p=dropout)
         layers['conv_bottleneck'] = nn.utils.weight_norm(
-            nn.Conv2D(
+            Conv2D(
                 in_channels=bottlececk_dim,
                 out_channels=bottlececk_dim,
                 kernel_size=(kernel_size, 1)),

@@ -115,7 +116,7 @@ class ConvGLUBlock(nn.Layer):
         layers['dropout'] = nn.Dropout(p=dropout)
         layers['glu'] = GLU()
         layers['conv_out'] = nn.utils.weight_norm(
-            nn.Conv2D(
+            Conv2D(
                 in_channels=bottlececk_dim,
                 out_channels=out_ch * 2,
                 kernel_size=(1, 1)),
paddlespeech/s2t/modules/align.py  (new file, +74 -0)

import paddle
from paddle import nn

from paddlespeech.s2t.modules.initializer import KaimingUniform

"""
To align the initializer between paddle and torch,
the API below are set defalut initializer with priority higger than global initializer.
"""
global_init_type = None


class LayerNorm(nn.LayerNorm):
    def __init__(self,
                 normalized_shape,
                 epsilon=1e-05,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        if weight_attr is None:
            weight_attr = paddle.ParamAttr(
                initializer=nn.initializer.Constant(1.0))
        if bias_attr is None:
            bias_attr = paddle.ParamAttr(
                initializer=nn.initializer.Constant(0.0))
        super(LayerNorm, self).__init__(normalized_shape, epsilon, weight_attr,
                                        bias_attr, name)


class BatchNorm1D(nn.BatchNorm1D):
    def __init__(self,
                 num_features,
                 momentum=0.9,
                 epsilon=1e-05,
                 weight_attr=None,
                 bias_attr=None,
                 data_format='NCL',
                 name=None):
        if weight_attr is None:
            weight_attr = paddle.ParamAttr(
                initializer=nn.initializer.Constant(1.0))
        if bias_attr is None:
            bias_attr = paddle.ParamAttr(
                initializer=nn.initializer.Constant(0.0))
        super(BatchNorm1D, self).__init__(num_features, momentum, epsilon,
                                          weight_attr, bias_attr, data_format,
                                          name)


class Embedding(nn.Embedding):
    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 sparse=False,
                 weight_attr=None,
                 name=None):
        if weight_attr is None:
            weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal())
        super(Embedding, self).__init__(num_embeddings, embedding_dim,
                                        padding_idx, sparse, weight_attr, name)


class Linear(nn.Linear):
    def __init__(self,
                 in_features,
                 out_features,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        if weight_attr is None:
            if global_init_type == "kaiming_uniform":
                weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
        if bias_attr is None:
            if global_init_type == "kaiming_uniform":
                bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
        super(Linear, self).__init__(in_features, out_features, weight_attr,
                                     bias_attr, name)


class Conv1D(nn.Conv1D):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 weight_attr=None,
                 bias_attr=None,
                 data_format='NCL'):
        if weight_attr is None:
            if global_init_type == "kaiming_uniform":
                print("set kaiming_uniform")
                weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
        if bias_attr is None:
            if global_init_type == "kaiming_uniform":
                bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
        super(Conv1D, self).__init__(in_channels, out_channels, kernel_size,
                                     stride, padding, dilation, groups,
                                     padding_mode, weight_attr, bias_attr,
                                     data_format)


class Conv2D(nn.Conv2D):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 weight_attr=None,
                 bias_attr=None,
                 data_format='NCHW'):
        if weight_attr is None:
            if global_init_type == "kaiming_uniform":
                weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
        if bias_attr is None:
            if global_init_type == "kaiming_uniform":
                bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
        super(Conv2D, self).__init__(in_channels, out_channels, kernel_size,
                                     stride, padding, dilation, groups,
                                     padding_mode, weight_attr, bias_attr,
                                     data_format)
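Note: these wrappers differ from the stock paddle.nn layers only in their default ParamAttr: LayerNorm/BatchNorm1D always start from weight 1.0 and bias 0.0, Embedding from a normal distribution, while Linear/Conv1D/Conv2D switch to the custom KaimingUniform only when global_init_type is set. A small sketch of that behaviour (setting the module flag directly here; DefaultInitializerContext does the same thing under the hood):

    from paddlespeech.s2t.modules import align

    # Norm layers are always initialized to weight=1.0, bias=0.0.
    ln = align.LayerNorm(8)
    print(float(ln.weight.mean()), float(ln.bias.mean()))  # 1.0 0.0

    # Linear/Conv layers only swap in KaimingUniform when the flag is set.
    align.global_init_type = "kaiming_uniform"
    fc = align.Linear(8, 8)   # weights drawn roughly from U(-sqrt(1/8), sqrt(1/8))
    align.global_init_type = None  # restore the default behaviour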
paddlespeech/s2t/modules/attention.py  (+6 -5)

@@ -22,6 +22,7 @@ import paddle
 from paddle import nn
 from paddle.nn import initializer as I
+from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()

@@ -48,10 +49,10 @@ class MultiHeadedAttention(nn.Layer):
     # We assume d_v always equals d_k
     self.d_k = n_feat // n_head
     self.h = n_head
-    self.linear_q = nn.Linear(n_feat, n_feat)
-    self.linear_k = nn.Linear(n_feat, n_feat)
-    self.linear_v = nn.Linear(n_feat, n_feat)
-    self.linear_out = nn.Linear(n_feat, n_feat)
+    self.linear_q = Linear(n_feat, n_feat)
+    self.linear_k = Linear(n_feat, n_feat)
+    self.linear_v = Linear(n_feat, n_feat)
+    self.linear_out = Linear(n_feat, n_feat)
     self.dropout = nn.Dropout(p=dropout_rate)

 def forward_qkv(self,

@@ -150,7 +151,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
     """
     super().__init__(n_head, n_feat, dropout_rate)
     # linear transformation for positional encoding
-    self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
+    self.linear_pos = Linear(n_feat, n_feat, bias_attr=False)
     # these two learnable bias are used in matrix c and matrix d
     # as described in https://arxiv.org/abs/1901.02860 Section 3.3
     #self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
paddlespeech/s2t/modules/conformer_convolution.py  (+9 -16)

@@ -21,6 +21,9 @@ import paddle
 from paddle import nn
 from typeguard import check_argument_types
+from paddlespeech.s2t.modules.align import BatchNorm1D
+from paddlespeech.s2t.modules.align import Conv1D
+from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()

@@ -49,7 +52,7 @@ class ConvolutionModule(nn.Layer):
     """
     assert check_argument_types()
     super().__init__()
-    self.pointwise_conv1 = nn.Conv1D(
+    self.pointwise_conv1 = Conv1D(
         channels,
         2 * channels,
         kernel_size=1,

@@ -73,7 +76,7 @@ class ConvolutionModule(nn.Layer):
         padding = (kernel_size - 1) // 2
         self.lorder = 0
-    self.depthwise_conv = nn.Conv1D(
+    self.depthwise_conv = Conv1D(
         channels,
         channels,
         kernel_size,

@@ -87,22 +90,12 @@ class ConvolutionModule(nn.Layer):
     assert norm in ['batch_norm', 'layer_norm']
     if norm == "batch_norm":
         self.use_layer_norm = False
-        self.norm = nn.BatchNorm1D(
-            channels,
-            weight_attr=paddle.ParamAttr(
-                initializer=nn.initializer.Constant(1.0)),
-            bias_attr=paddle.ParamAttr(
-                initializer=nn.initializer.Constant(0.0)))
+        self.norm = BatchNorm1D(channels)
     else:
         self.use_layer_norm = True
-        self.norm = nn.LayerNorm(
-            channels,
-            weight_attr=paddle.ParamAttr(
-                initializer=nn.initializer.Constant(1.0)),
-            bias_attr=paddle.ParamAttr(
-                initializer=nn.initializer.Constant(0.0)))
+        self.norm = LayerNorm(channels)

-    self.pointwise_conv2 = nn.Conv1D(
+    self.pointwise_conv2 = Conv1D(
         channels,
         channels,
         kernel_size=1,
paddlespeech/s2t/modules/ctc.py  (+2 -1)

@@ -18,6 +18,7 @@ from paddle import nn
 from paddle.nn import functional as F
 from typeguard import check_argument_types
+from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.loss import CTCLoss
 from paddlespeech.s2t.utils import ctc_utils
 from paddlespeech.s2t.utils.log import Log

@@ -69,7 +70,7 @@ class CTCDecoderBase(nn.Layer):
     self.blank_id = blank_id
     self.odim = odim
     self.dropout = nn.Dropout(dropout_rate)
-    self.ctc_lo = nn.Linear(enc_n_units, self.odim)
+    self.ctc_lo = Linear(enc_n_units, self.odim)
     reduction_type = "sum" if reduction else "none"
     self.criterion = CTCLoss(
         blank=self.blank_id,
paddlespeech/s2t/modules/decoder.py  (+6 -13)

@@ -24,6 +24,9 @@ from paddle import nn
 from typeguard import check_argument_types
 from paddlespeech.s2t.decoders.scorers.scorer_interface import BatchScorerInterface
+from paddlespeech.s2t.modules.align import Embedding
+from paddlespeech.s2t.modules.align import LayerNorm
+from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.attention import MultiHeadedAttention
 from paddlespeech.s2t.modules.decoder_layer import DecoderLayer
 from paddlespeech.s2t.modules.embedding import PositionalEncoding

@@ -83,25 +86,15 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
     if input_layer == "embed":
         self.embed = nn.Sequential(
-            nn.Embedding(
-                vocab_size,
-                attention_dim,
-                weight_attr=paddle.ParamAttr(
-                    initializer=nn.initializer.Normal())),
+            Embedding(vocab_size, attention_dim),
             PositionalEncoding(attention_dim, positional_dropout_rate), )
     else:
         raise ValueError(f"only 'embed' is supported: {input_layer}")

     self.normalize_before = normalize_before
-    self.after_norm = nn.LayerNorm(
-        attention_dim,
-        epsilon=1e-12,
-        weight_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(1.0)),
-        bias_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(0.0)))
+    self.after_norm = LayerNorm(attention_dim, epsilon=1e-12)
     self.use_output_layer = use_output_layer
-    self.output_layer = nn.Linear(attention_dim, vocab_size)
+    self.output_layer = Linear(attention_dim, vocab_size)
     self.decoders = nn.LayerList([
         DecoderLayer(
paddlespeech/s2t/modules/decoder_layer.py  (+7 -23)

@@ -20,6 +20,8 @@ from typing import Tuple
 import paddle
 from paddle import nn
+from paddlespeech.s2t.modules.align import LayerNorm
+from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()

@@ -62,32 +64,14 @@ class DecoderLayer(nn.Layer):
     self.self_attn = self_attn
     self.src_attn = src_attn
     self.feed_forward = feed_forward
-    self.norm1 = nn.LayerNorm(
-        size,
-        epsilon=1e-12,
-        weight_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(1.0)),
-        bias_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(0.0)))
-    self.norm2 = nn.LayerNorm(
-        size,
-        epsilon=1e-12,
-        weight_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(1.0)),
-        bias_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(0.0)))
-    self.norm3 = nn.LayerNorm(
-        size,
-        epsilon=1e-12,
-        weight_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(1.0)),
-        bias_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(0.0)))
+    self.norm1 = LayerNorm(size, epsilon=1e-12)
+    self.norm2 = LayerNorm(size, epsilon=1e-12)
+    self.norm3 = LayerNorm(size, epsilon=1e-12)
     self.dropout = nn.Dropout(dropout_rate)
     self.normalize_before = normalize_before
     self.concat_after = concat_after
-    self.concat_linear1 = nn.Linear(size + size, size)
-    self.concat_linear2 = nn.Linear(size + size, size)
+    self.concat_linear1 = Linear(size + size, size)
+    self.concat_linear2 = Linear(size + size, size)

 def forward(
     self,
paddlespeech/s2t/modules/encoder.py  (+3 -7)

@@ -23,6 +23,8 @@ from paddle import nn
 from typeguard import check_argument_types
 from paddlespeech.s2t.modules.activation import get_activation
+from paddlespeech.s2t.modules.align import LayerNorm
+from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.attention import MultiHeadedAttention
 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule

@@ -129,13 +131,7 @@ class BaseEncoder(nn.Layer):
         d_model=output_size, dropout_rate=positional_dropout_rate), )

     self.normalize_before = normalize_before
-    self.after_norm = nn.LayerNorm(
-        output_size,
-        epsilon=1e-12,
-        weight_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(1.0)),
-        bias_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(0.0)))
+    self.after_norm = LayerNorm(output_size, epsilon=1e-12)
     self.static_chunk_size = static_chunk_size
     self.use_dynamic_chunk = use_dynamic_chunk
     self.use_dynamic_left_chunk = use_dynamic_left_chunk
paddlespeech/s2t/modules/encoder_layer.py  (+13 -39)

@@ -20,6 +20,8 @@ from typing import Tuple
 import paddle
 from paddle import nn
+from paddlespeech.s2t.modules.align import LayerNorm
+from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()

@@ -59,15 +61,15 @@ class TransformerEncoderLayer(nn.Layer):
     super().__init__()
     self.self_attn = self_attn
     self.feed_forward = feed_forward
-    self.norm1 = nn.LayerNorm(size, epsilon=1e-12)
-    self.norm2 = nn.LayerNorm(size, epsilon=1e-12)
+    self.norm1 = LayerNorm(size, epsilon=1e-12)
+    self.norm2 = LayerNorm(size, epsilon=1e-12)
     self.dropout = nn.Dropout(dropout_rate)
     self.size = size
     self.normalize_before = normalize_before
     self.concat_after = concat_after
     # concat_linear may be not used in forward fuction,
     # but will be saved in the *.pt
-    self.concat_linear = nn.Linear(size + size, size)
+    self.concat_linear = Linear(size + size, size)

 def forward(
     self,

@@ -174,51 +176,23 @@ class ConformerEncoderLayer(nn.Layer):
     self.feed_forward = feed_forward
     self.feed_forward_macaron = feed_forward_macaron
     self.conv_module = conv_module
-    self.norm_ff = nn.LayerNorm(
-        size,
-        epsilon=1e-12,
-        weight_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(1.0)),
-        bias_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(0.0)))  # for the FNN module
-    self.norm_mha = nn.LayerNorm(
-        size,
-        epsilon=1e-12,
-        weight_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(1.0)),
-        bias_attr=paddle.ParamAttr(
-            initializer=nn.initializer.Constant(0.0)))  # for the MHA module
+    self.norm_ff = LayerNorm(
+        size, epsilon=1e-12)  # for the FNN module
+    self.norm_mha = LayerNorm(size, epsilon=1e-12)  # for the MHA module
     if feed_forward_macaron is not None:
-        self.norm_ff_macaron = nn.LayerNorm(
-            size,
-            epsilon=1e-12,
-            weight_attr=paddle.ParamAttr(
-                initializer=nn.initializer.Constant(1.0)),
-            bias_attr=paddle.ParamAttr(
-                initializer=nn.initializer.Constant(0.0)))
+        self.norm_ff_macaron = LayerNorm(size, epsilon=1e-12)
         self.ff_scale = 0.5
     else:
         self.ff_scale = 1.0
     if self.conv_module is not None:
-        self.norm_conv = nn.LayerNorm(
-            size,
-            epsilon=1e-12,
-            weight_attr=paddle.ParamAttr(
-                initializer=nn.initializer.Constant(1.0)),
-            bias_attr=paddle.ParamAttr(
-                initializer=nn.initializer.Constant(0.0)))  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            size,
-            epsilon=1e-12,
-            weight_attr=paddle.ParamAttr(
-                initializer=nn.initializer.Constant(1.0)),
-            bias_attr=paddle.ParamAttr(
-                initializer=nn.initializer.Constant(0.0)))  # for the final output of the block
+        self.norm_conv = LayerNorm(
+            size, epsilon=1e-12)  # for the CNN module
+        self.norm_final = LayerNorm(
+            size, epsilon=1e-12)  # for the final output of the block
     self.dropout = nn.Dropout(dropout_rate)
     self.size = size
     self.normalize_before = normalize_before
     self.concat_after = concat_after
-    self.concat_linear = nn.Linear(size + size, size)
+    self.concat_linear = Linear(size + size, size)

 def forward(
     self,
paddlespeech/s2t/modules/initializer.py  (+44 -141)

@@ -11,93 +11,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import print_function
-
-from paddle.fluid import framework
-from paddle.fluid.framework import in_dygraph_mode, default_main_program
 import numpy as np
-from paddle.fluid.core import VarDesc
+from paddle import nn
+from paddle.fluid import framework
 from paddle.fluid import unique_name
+from paddle.fluid.core import VarDesc
+from paddle.fluid.framework import default_main_program
+from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid.initializer import Initializer
+from paddle.fluid.initializer import MSRAInitializer
+from typeguard import check_argument_types

-__all__ = ['MSRAInitializer']
+__all__ = ['KaimingUniform']

-
-class Initializer(object):
-    """Base class for variable initializers
-    Defines the common interface of variable initializers.
-    They add operations to the init program that are used
-    to initialize variables. Users should not use this class
-    directly, but need to use one of its implementations.
-    """
-
-    def __init__(self):
-        pass
-
-    def __call__(self, param, block=None):
-        """Add corresponding initialization operations to the network
-        """
-        raise NotImplementedError()
-
-    def _check_block(self, block):
-        if block is None:
-            block = default_main_program().global_block()
-        return block
-
-    def _compute_fans(self, var):
-        """Compute the fan_in and the fan_out for layers
-        This method computes the fan_in and the fan_out
-        for neural network layers, if not specified. It is
-        not possible to perfectly estimate fan_in and fan_out.
-        This method will estimate it correctly for matrix multiply and
-        convolutions.
-        Args:
-            var: variable for which fan_in and fan_out have to be computed
-        Returns:
-            tuple of two integers (fan_in, fan_out)
-        """
-        shape = var.shape
-        if not shape or len(shape) == 0:
-            fan_in = fan_out = 1
-        elif len(shape) == 1:
-            fan_in = fan_out = shape[0]
-        elif len(shape) == 2:
-            # This is the case for simple matrix multiply
-            fan_in = shape[0]
-            fan_out = shape[1]
-        else:
-            # Assume this to be a convolutional kernel
-            # In PaddlePaddle, the shape of the kernel is like:
-            # [num_filters, num_filter_channels, ...] where the remaining
-            # dimensions are the filter_size
-            receptive_field_size = np.prod(shape[2:])
-            fan_in = shape[1] * receptive_field_size
-            fan_out = shape[0] * receptive_field_size
-
-        return (fan_in, fan_out)
-
-
-class MSRAInitializer(Initializer):
-    r"""Implements the MSRA initializer a.k.a. Kaiming Initializer
+class KaimingUniform(MSRAInitializer):
+    r"""Implements the Kaiming Uniform initializer
     This class implements the weight initialization from the paper
     `Delving Deep into Rectifiers: Surpassing Human-Level Performance on
     ImageNet Classification <https://arxiv.org/abs/1502.01852>`_
     by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a
     robust initialization method that particularly considers the rectifier
     nonlinearities.
     In case of Uniform distribution, the range is [-x, x], where
     .. math::
-        x = \sqrt{\frac{6.0}{fan\_in}}
+        x = \sqrt{\frac{1.0}{fan\_in}}
     In case of Normal distribution, the mean is 0 and the standard deviation
     is

@@ -107,10 +49,8 @@ class MSRAInitializer(Initializer):
         \sqrt{\\frac{2.0}{fan\_in}}
     Args:
-        uniform (bool): whether to use uniform or normal distribution
-        fan_in (float32|None): fan_in for MSRAInitializer. If None, it is\
+        fan_in (float32|None): fan_in for Kaiming uniform Initializer. If None, it is\
         inferred from the variable. default is None.
-        seed (int32): random seed
     Note:
         It is recommended to set fan_in to None for most cases.

@@ -119,23 +59,19 @@ class MSRAInitializer(Initializer):
     Examples:
         .. code-block:: python
             import paddle
-            import paddle.fluid as fluid
-            paddle.enable_static()
-            x = fluid.data(name="data", shape=[8, 32, 32], dtype="float32")
-            fc = fluid.layers.fc(input=x, size=10,
-                param_attr=fluid.initializer.MSRA(uniform=False))
+            import paddle.nn as nn
+            linear = nn.Linear(2,
+                               4,
+                               weight_attr=nn.initializer.KaimingUniform())
+            data = paddle.rand([30, 10, 2], dtype='float32')
+            res = linear(data)
     """

-    def __init__(self, uniform=True, fan_in=None, seed=0):
-        """Constructor for MSRAInitializer
-        """
-        assert uniform is not None
-        assert seed is not None
-        super(MSRAInitializer, self).__init__()
-        self._uniform = uniform
-        self._fan_in = fan_in
-        self._seed = seed
+    def __init__(self, fan_in=None):
+        super(KaimingUniform, self).__init__(
+            uniform=True, fan_in=fan_in, seed=0)

     def __call__(self, var, block=None):
         """Initialize the input tensor with MSRA initialization.

@@ -165,8 +101,8 @@ class MSRAInitializer(Initializer):
         if var.dtype == VarDesc.VarType.BF16 and not self._uniform:
             out_dtype = VarDesc.VarType.FP32
             out_var = block.create_var(
-                name=unique_name.generate(".".join(
-                    ['masra_init', var.name, 'tmp'])),
+                name=unique_name.generate(
+                    ".".join(['masra_init', var.name, 'tmp'])),
                 shape=var.shape,
                 dtype=out_dtype,
                 type=VarDesc.VarType.LOD_TENSOR,

@@ -217,56 +153,23 @@ class MSRAInitializer(Initializer):
         var.op = op
         return op

-
-class KaimingUniform(MSRAInitializer):
-    r"""Implements the Kaiming Uniform initializer
-    This class implements the weight initialization from the paper
-    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on
-    ImageNet Classification <https://arxiv.org/abs/1502.01852>`_
-    by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a
-    robust initialization method that particularly considers the rectifier
-    nonlinearities.
-    In case of Uniform distribution, the range is [-x, x], where
-    .. math::
-        x = \sqrt{\frac{6.0}{fan\_in}}
-    Args:
-        fan_in (float32|None): fan_in for Kaiming uniform Initializer. If None, it is\
-            inferred from the variable. default is None.
-    Note:
-        It is recommended to set fan_in to None for most cases.
-    Examples:
-        .. code-block:: python
-            import paddle
-            import paddle.nn as nn
-            linear = nn.Linear(2,
-                               4,
-                               weight_attr=nn.initializer.KaimingUniform())
-            data = paddle.rand([30, 10, 2], dtype='float32')
-            res = linear(data)
-    """
-
-    def __init__(self, fan_in=None):
-        super(KaimingUniform, self).__init__(
-            uniform=True, fan_in=fan_in, seed=0)
-
-
-# We short the class name, since users will use the initializer with the package
-# name. The sample code:
-#
-# import paddle.fluid as fluid
-#
-# hidden = fluid.layers.fc(...,
-#                          param_attr=ParamAttr(fluid.initializer.Xavier()))
-#
-# It is no need to add an `Initializer` as the class suffix
-MSRA = MSRAInitializer
+class DefaultInitializerContext(object):
+    """
+    egs:
+    with DefaultInitializerContext("kaiming_uniform"):
+        code for setup_model
+    """
+
+    def __init__(self, init_type=None):
+        self.init_type = init_type
+
+    def __enter__(self):
+        from paddlespeech.s2t.modules import align
+        align.global_init_type = self.init_type
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        from paddlespeech.s2t.modules import align
+        align.global_init_type = None
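Note: the locally duplicated Initializer base class (including _compute_fans) is removed in favour of paddle.fluid.initializer. Its fan_in rule was: shape[0] for a 2-D weight, and shape[1] times the receptive-field size for a conv kernel; the custom KaimingUniform then draws from U(-x, x) with x = sqrt(1.0 / fan_in). A stand-alone sketch of that arithmetic (shapes chosen only for illustration):

    import numpy as np

    def compute_fan_in(shape):
        """fan_in part of the removed _compute_fans helper."""
        if not shape:
            return 1
        if len(shape) == 1:
            return shape[0]
        if len(shape) == 2:
            return shape[0]                        # plain matmul weight
        return shape[1] * int(np.prod(shape[2:]))  # conv kernel

    def kaiming_uniform_limit(shape):
        return np.sqrt(1.0 / float(compute_fan_in(shape)))

    print(compute_fan_in([256, 512]))                    # 256
    print(compute_fan_in([64, 3, 3, 3]))                 # 27
    print(round(kaiming_uniform_limit([256, 512]), 4))   # 0.0625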
paddlespeech/s2t/modules/nets_utils.py  (deleted, +0 -44)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from paddle import nn
from typeguard import check_argument_types


def initialize(model: nn.Layer, init: str):
    """Initialize weights of a neural network module.
    Parameters are initialized using the given method or distribution.
    Custom initialization routines can be implemented into submodules
    Args:
        model (nn.Layer): Target.
        init (str): Method of initialization.
    """
    assert check_argument_types()

    if init == "xavier_uniform":
        nn.initializer.set_global_initializer(nn.initializer.XavierUniform(),
                                              nn.initializer.Constant())
    elif init == "xavier_normal":
        nn.initializer.set_global_initializer(nn.initializer.XavierNormal(),
                                              nn.initializer.Constant())
    elif init == "kaiming_uniform":
        nn.initializer.set_global_initializer(nn.initializer.KaimingUniform(),
                                              nn.initializer.KaimingUniform())
    elif init == "kaiming_normal":
        nn.initializer.set_global_initializer(nn.initializer.KaimingNormal(),
                                              nn.initializer.Constant())
    else:
        raise ValueError("Unknown initialization: " + init)
paddlespeech/s2t/modules/positionwise_feed_forward.py  (+3 -2)

@@ -17,6 +17,7 @@
 import paddle
 from paddle import nn
+from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()

@@ -44,10 +45,10 @@ class PositionwiseFeedForward(nn.Layer):
         activation (paddle.nn.Layer): Activation function
     """
     super().__init__()
-    self.w_1 = nn.Linear(idim, hidden_units)
+    self.w_1 = Linear(idim, hidden_units)
     self.activation = activation
     self.dropout = nn.Dropout(dropout_rate)
-    self.w_2 = nn.Linear(hidden_units, idim)
+    self.w_2 = Linear(hidden_units, idim)

 def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
     """Forward function.
paddlespeech/s2t/modules/subsampling.py  (+16 -13)

@@ -19,6 +19,9 @@ from typing import Tuple
 import paddle
 from paddle import nn
+from paddlespeech.s2t.modules.align import Conv2D
+from paddlespeech.s2t.modules.align import LayerNorm
+from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.embedding import PositionalEncoding
 from paddlespeech.s2t.utils.log import Log

@@ -60,8 +63,8 @@ class LinearNoSubsampling(BaseSubsampling):
     """
     super().__init__(pos_enc_class)
     self.out = nn.Sequential(
-        nn.Linear(idim, odim),
-        nn.LayerNorm(odim, epsilon=1e-12),
+        Linear(idim, odim),
+        LayerNorm(odim, epsilon=1e-12),
         nn.Dropout(dropout_rate),
         nn.ReLU(), )
     self.right_context = 0

@@ -108,12 +111,12 @@ class Conv2dSubsampling4(Conv2dSubsampling):
     """
     super().__init__(pos_enc_class)
     self.conv = nn.Sequential(
-        nn.Conv2D(1, odim, 3, 2),
+        Conv2D(1, odim, 3, 2),
         nn.ReLU(),
-        nn.Conv2D(odim, odim, 3, 2),
+        Conv2D(odim, odim, 3, 2),
         nn.ReLU(), )
     self.out = nn.Sequential(
-        nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
+        Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
     self.subsampling_rate = 4
     # The right context for every conv layer is computed by:
     # (kernel_size - 1) * frame_rate_of_this_layer

@@ -160,13 +163,13 @@ class Conv2dSubsampling6(Conv2dSubsampling):
     """
     super().__init__(pos_enc_class)
     self.conv = nn.Sequential(
-        nn.Conv2D(1, odim, 3, 2),
+        Conv2D(1, odim, 3, 2),
         nn.ReLU(),
-        nn.Conv2D(odim, odim, 5, 3),
+        Conv2D(odim, odim, 5, 3),
         nn.ReLU(), )
     # O = (I - F + Pstart + Pend) // S + 1
     # when Padding == 0, O = (I - F - S) // S
-    self.linear = nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
+    self.linear = Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
     # The right context for every conv layer is computed by:
     # (kernel_size - 1) * frame_rate_of_this_layer
     # 10 = (3 - 1) * 1 + (5 - 1) * 2

@@ -212,14 +215,14 @@ class Conv2dSubsampling8(Conv2dSubsampling):
     """
     super().__init__(pos_enc_class)
     self.conv = nn.Sequential(
-        nn.Conv2D(1, odim, 3, 2),
+        Conv2D(1, odim, 3, 2),
         nn.ReLU(),
-        nn.Conv2D(odim, odim, 3, 2),
+        Conv2D(odim, odim, 3, 2),
         nn.ReLU(),
-        nn.Conv2D(odim, odim, 3, 2),
+        Conv2D(odim, odim, 3, 2),
         nn.ReLU(), )
-    self.linear = nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2),
-                            odim)
+    self.linear = Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2),
+                         odim)
     self.subsampling_rate = 8
     # The right context for every conv layer is computed by:
     # (kernel_size - 1) * frame_rate_of_this_layer
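Note: the in_features passed to each Linear above follows directly from the stride arithmetic in the file's comments: with no padding, a Conv2D with kernel F and stride S maps a frequency dimension I to (I - F) // S + 1, and the result is multiplied by odim after flattening. A quick check of the three expressions (the idim and odim values are placeholders, not from the commit):

    def conv_out(i, kernel, stride):
        # O = (I - F) // S + 1, zero padding
        return (i - kernel) // stride + 1

    idim, odim = 80, 256  # e.g. 80-dim fbank features, 256-dim model

    # Conv2dSubsampling4: two (kernel=3, stride=2) convs
    f4 = conv_out(conv_out(idim, 3, 2), 3, 2)
    assert f4 == ((idim - 1) // 2 - 1) // 2            # 19 -> Linear(odim * 19, odim)

    # Conv2dSubsampling6: a (3, 2) conv followed by a (5, 3) conv
    f6 = conv_out(conv_out(idim, 3, 2), 5, 3)
    assert f6 == ((idim - 1) // 2 - 2) // 3            # 12

    # Conv2dSubsampling8: three (3, 2) convs
    f8 = conv_out(conv_out(conv_out(idim, 3, 2), 3, 2), 3, 2)
    assert f8 == (((idim - 1) // 2 - 1) // 2 - 1) // 2  # 9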