PaddlePaddle / Parakeet, commit 6068374a

Authored Feb 06, 2020 by lifuchen; committed by chenfeiyu, Feb 06, 2020
Parent: 47a618ce

    modified fastspeech network

Showing 6 changed files with 25 additions and 14 deletions (+25 −14):
parakeet/models/fastspeech/config/fastspeech.yaml   +2  −2
parakeet/models/fastspeech/modules.py               +1  −1
parakeet/models/fastspeech/network.py               +2  −1
parakeet/models/fastspeech/train.py                 +0  −1
parakeet/modules/multihead_attention.py             +11 −6
parakeet/modules/post_convnet.py                    +9  −3
parakeet/models/fastspeech/config/fastspeech.yaml (+2 −2)

The STFT hop_length changes from 275 to 256 samples and win_length from 1102 to 1024.

@@ -3,8 +3,8 @@ audio:
   n_fft: 2048
   sr: 22050
   preemphasis: 0.97
-  hop_length: 275
-  win_length: 1102
+  hop_length: 256
+  win_length: 1024
   power: 1.2
   min_level_db: -100
   ref_level_db: 20
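For orientation, not part of the commit itself: at sr 22050, the new hop_length of 256 samples is roughly 11.6 ms per frame and win_length 1024 is roughly 46 ms. A minimal sketch of the analysis these settings describe, using librosa as a stand-in for Parakeet's own audio pipeline (the example waveform is arbitrary):

    import numpy as np
    import librosa

    # Values from the updated fastspeech.yaml.
    sr, n_fft, hop_length, win_length, preemph = 22050, 2048, 256, 1024, 0.97

    y, _ = librosa.load(librosa.example("trumpet"), sr=sr)  # any mono waveform works
    y = np.append(y[0], y[1:] - preemph * y[:-1])           # preemphasis filter

    # Linear spectrogram with the new hop/window sizes.
    spec = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
                               win_length=win_length))
    print(spec.shape)  # (1 + n_fft // 2, n_frames); hop of 256 ~= 11.6 ms at 22050 Hz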
parakeet/models/fastspeech/modules.py (+1 −1)

FFTBlock now passes the new MultiheadAttention flags, enabling biases on the projections and disabling the query/result concatenation.

@@ -11,7 +11,7 @@ from parakeet.modules.feed_forward import PositionwiseFeedForward
 class FFTBlock(dg.Layer):
     def __init__(self, d_model, d_inner, n_head, d_k, d_v, filter_size, padding, dropout=0.2):
         super(FFTBlock, self).__init__()
-        self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, dropout=dropout)
+        self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, is_bias=True, dropout=dropout, is_concat=False)
         self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size=filter_size, padding=padding, dropout=dropout)

     def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
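A construction sketch under the updated signature, assuming the repo is on PYTHONPATH and PaddlePaddle 1.x dygraph; the layer sizes here are hypothetical, not taken from the commit:

    import paddle.fluid.dygraph as dg
    from parakeet.models.fastspeech.modules import FFTBlock

    with dg.guard():
        # Hypothetical FastSpeech-like sizes; the new attention flags are fixed
        # internally (is_bias=True, is_concat=False), so callers construct
        # FFTBlock exactly as before this commit.
        block = FFTBlock(d_model=384, d_inner=1536, n_head=2, d_k=64, d_v=64,
                         filter_size=3, padding=1, dropout=0.2)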
parakeet/models/fastspeech/network.py (+2 −1)

FastSpeech enables the new batchnorm_last option on its PostConvNet.

@@ -161,7 +161,8 @@ class FastSpeech(dg.Layer):
                                  num_conv=5,
                                  outputs_per_step=cfg.audio.outputs_per_step,
                                  use_cudnn=True,
-                                 dropout=0.1)
+                                 dropout=0.1,
+                                 batchnorm_last=True)

     def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0):
         """
parakeet/models/fastspeech/train.py (+0 −1)

The only change here is one deleted line, apparently a stray blank line between load_checkpoint and main.

@@ -29,7 +29,6 @@ def load_checkpoint(step, model_path):
     return new_state_dict, opti_dict

-
 def main(cfg):
     local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
     nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1
parakeet/modules/multihead_attention.py (+11 −6)

MultiheadAttention gains two options: is_bias, which puts biases on the key/value/query projections (previously hard-coded off), and is_concat, which makes concatenating the query input onto the attention output, and the doubled fc width it requires, optional.

@@ -47,21 +47,25 @@ class ScaledDotProductAttention(dg.Layer):
         return result, attention


 class MultiheadAttention(dg.Layer):
-    def __init__(self, num_hidden, d_k, d_q, num_head=4, dropout=0.1):
+    def __init__(self, num_hidden, d_k, d_q, num_head=4, is_bias=False, dropout=0.1, is_concat=True):
         super(MultiheadAttention, self).__init__()
         self.num_hidden = num_hidden
         self.num_head = num_head
         self.d_k = d_k
         self.d_q = d_q
         self.dropout = dropout
+        self.is_concat = is_concat

-        self.key = Linear(num_hidden, num_head * d_k, is_bias=False)
-        self.value = Linear(num_hidden, num_head * d_k, is_bias=False)
-        self.query = Linear(num_hidden, num_head * d_q, is_bias=False)
+        self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
+        self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
+        self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias)

         self.scal_attn = ScaledDotProductAttention(d_k)

-        self.fc = Linear(num_head * d_q * 2, num_hidden)
+        if self.is_concat:
+            self.fc = Linear(num_head * d_q * 2, num_hidden)
+        else:
+            self.fc = Linear(num_head * d_q, num_hidden)

         self.layer_norm = dg.LayerNorm(num_hidden)

@@ -105,7 +109,8 @@ class MultiheadAttention(dg.Layer):
         # concat all multihead result
         result = layers.reshape(result, [self.num_head, batch_size, seq_len_query, self.d_q])
         result = layers.reshape(layers.transpose(result, [1, 2, 0, 3]), [batch_size, seq_len_query, -1])
-        result = layers.concat([query_input, result], axis=-1)
+        if self.is_concat:
+            result = layers.concat([query_input, result], axis=-1)
         result = layers.dropout(self.fc(result), self.dropout)
         result = result + query_input
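A framework-free sketch of the projection width the new flag selects. The sizes are assumed for illustration; note that the concat path sizes fc as num_head * d_q * 2, which presumes num_hidden == num_head * d_q, as the original Linear dimensions imply:

    # Output-projection input width, mirroring the __init__ logic above.
    num_hidden, num_head, d_q = 256, 4, 64  # assumed; num_hidden == num_head * d_q

    for is_concat in (True, False):
        # is_concat=True: the per-head outputs (num_head * d_q features) are
        # concatenated with the query input (num_hidden features), doubling
        # what fc sees; is_concat=False feeds fc the attention output alone.
        fc_in = num_head * d_q * 2 if is_concat else num_head * d_q
        print("is_concat=%s: fc maps %d -> %d features" % (is_concat, fc_in, num_hidden))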
parakeet/modules/post_convnet.py (+9 −3)

PostConvNet gains a batchnorm_last flag. When set, a BatchNorm layer is also created for the final convolution (replacing previously commented-out code) and applied, with dropout, in forward.

@@ -12,11 +12,13 @@ class PostConvNet(dg.Layer):
                  num_conv=5,
                  outputs_per_step=1,
                  use_cudnn=True,
-                 dropout=0.1):
+                 dropout=0.1,
+                 batchnorm_last=False):
         super(PostConvNet, self).__init__()

         self.dropout = dropout
         self.num_conv = num_conv
+        self.batchnorm_last = batchnorm_last
         self.conv_list = []
         self.conv_list.append(Conv(in_channels=n_mels * outputs_per_step,
                                    out_channels=num_hidden,

@@ -45,8 +47,9 @@ class PostConvNet(dg.Layer):
         self.batch_norm_list = [dg.BatchNorm(num_hidden,
                                              data_layout='NCHW')
                                 for _ in range(num_conv - 1)]
-        #self.batch_norm_list.append(dg.BatchNorm(n_mels * outputs_per_step,
-        #                                         data_layout='NCHW'))
+        if self.batchnorm_last:
+            self.batch_norm_list.append(dg.BatchNorm(n_mels * outputs_per_step,
+                                                     data_layout='NCHW'))
         for i, layer in enumerate(self.batch_norm_list):
             self.add_sublayer("batch_norm_list_{}".format(i), layer)

@@ -70,5 +73,8 @@ class PostConvNet(dg.Layer):
             input = layers.dropout(layers.tanh(batch_norm(conv(input)[:, :, :len])), self.dropout)
         conv = self.conv_list[self.num_conv - 1]
         input = conv(input)[:, :, :len]
+        if self.batchnorm_last:
+            batch_norm = self.batch_norm_list[self.num_conv - 1]
+            input = layers.dropout(batch_norm(input), self.dropout)
         output = layers.transpose(input, [0, 2, 1])
         return output
\ No newline at end of file
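A minimal NumPy sketch of the forward path the new flag alters, with identity stand-ins for the Conv and BatchNorm layers; the names and shapes are assumptions, and the hidden layers' dropout is omitted for brevity:

    import numpy as np

    def postnet_forward(x, conv_list, bn_list, num_conv, batchnorm_last, length):
        # Hidden layers: conv -> batch norm -> tanh, truncated to `length` frames.
        for i in range(num_conv - 1):
            x = np.tanh(bn_list[i](conv_list[i](x)[:, :, :length]))
        # Final conv; before this commit its output was never batch-normalized.
        x = conv_list[num_conv - 1](x)[:, :, :length]
        if batchnorm_last:
            x = bn_list[num_conv - 1](x)   # the extra BatchNorm added above
        return np.transpose(x, (0, 2, 1))  # (batch, channels, time) -> (batch, time, channels)

    identity = lambda t: t                 # stand-in for Conv/BatchNorm layers
    x = np.random.randn(2, 80, 100).astype("float32")
    print(postnet_forward(x, [identity] * 5, [identity] * 5, 5, True, 100).shape)  # (2, 100, 80)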