Commit a4dd5acc

Authored May 08, 2020 by liuyibing01

Merge branch 'develop' into 'master'

update for paddle 1.8.0

See merge request !51

Parents: 8505805d, e0ba85f6
Showing 9 changed files with 39 additions and 43 deletions (+39 -43)
README.md  +1 -1
examples/clarinet/train.py  +5 -5
examples/deepvoice3/train.py  +8 -4
examples/wavenet/train.py  +6 -7
parakeet/models/clarinet/utils.py  +6 -18
parakeet/models/deepvoice3/loss.py  +2 -2
parakeet/models/deepvoice3/position_embedding.py  +4 -2
parakeet/models/wavenet/wavenet.py  +3 -2
parakeet/modules/weight_norm.py  +4 -2
README.md

@@ -40,7 +40,7 @@ sudo apt-get install libsndfile1

 ### Install PaddlePaddle

-See [install](https://www.paddlepaddle.org.cn/install/quick) for more details. This repo requires PaddlePaddle **1.7.1** or above.
+See [install](https://www.paddlepaddle.org.cn/install/quick) for more details. This repo requires PaddlePaddle **1.8.0** or above.

 ### Install Parakeet
examples/clarinet/train.py

@@ -163,11 +163,11 @@ if __name__ == "__main__":
         anneal_interval = train_config["anneal_interval"]
         lr_scheduler = dg.ExponentialDecay(
             learning_rate, anneal_interval, anneal_rate, staircase=True)
-        optim = fluid.optimizer.Adam(
-            lr_scheduler, parameter_list=model.parameters())
         gradiant_max_norm = train_config["gradient_max_norm"]
-        clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
-            gradiant_max_norm)
+        optim = fluid.optimizer.Adam(
+            lr_scheduler,
+            parameter_list=model.parameters(),
+            grad_clip=fluid.clip.ClipByGlobalNorm(gradiant_max_norm))

         # train
         max_iterations = train_config["max_iterations"]

@@ -229,7 +229,7 @@ if __name__ == "__main__":
                         step_loss))

             l.backward()
-            optim.minimize(l, grad_clip=clipper)
+            optim.minimize(l)
             optim.clear_gradients()

             if global_step % eval_interval == 0:
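For reference, the Paddle 1.8 pattern these hunks adopt, clipping configured on the optimizer instead of passed to minimize(), can be sketched as below. This is a minimal sketch assuming paddlepaddle==1.8.0; the tiny Linear model and random data are placeholders, not Parakeet code, and it uses fluid.clip.GradientClipByGlobalNorm (the diff spells it ClipByGlobalNorm here and GradientClipByGlobalNorm in the deepvoice3 example).

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

with dg.guard():
    model = dg.Linear(4, 4)  # placeholder model
    # 1.8 attaches clipping to the optimizer at construction time...
    optim = fluid.optimizer.Adam(
        learning_rate=1e-3,
        parameter_list=model.parameters(),
        grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=0.1))
    x = dg.to_variable(np.random.randn(2, 4).astype("float32"))
    loss = fluid.layers.reduce_mean(model(x))
    loss.backward()
    # ...so the training loop calls minimize() with no grad_clip argument.
    optim.minimize(loss)
    optim.clear_gradients()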
examples/deepvoice3/train.py

@@ -196,8 +196,8 @@ if __name__ == "__main__":
         beta1,
         beta2,
         epsilon=epsilon,
-        parameter_list=dv3.parameters())
-    gradient_clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.1)
+        parameter_list=dv3.parameters(),
+        grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1))

     # generation
     synthesis_config = config["synthesis"]

@@ -258,15 +258,19 @@ if __name__ == "__main__":
                 text_lengths, frames)
             l = losses["loss"]
             l.backward()
             # record learning rate before updating
             writer.add_scalar("learning_rate",
                               optim._learning_rate.step().numpy(),
                               global_step)
-            optim.minimize(l, grad_clip=gradient_clipper)
+            optim.minimize(l)
             optim.clear_gradients()

             # ==================all kinds of tedious things=================
             # record step loss into tensorboard
-            step_loss = {k: v.numpy()[0] for k, v in losses.items()}
+            step_loss = {
+                k: v.numpy()[0]
+                for k, v in losses.items() if v is not None
+            }
             tqdm.tqdm.write("global_step: {}\tloss: {}".format(
                 global_step, step_loss["loss"]))
             for k, v in step_loss.items():
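The second hunk also filters out sub-losses that are disabled and therefore come back as None before logging. A toy illustration of the guarded comprehension (values are made up; the real code logs v.numpy()[0] per loss):

losses = {"loss": 1.2, "mel_loss": None}  # a disabled sub-loss is None
step_loss = {k: v for k, v in losses.items() if v is not None}
print(step_loss)  # {'loss': 1.2}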
examples/wavenet/train.py

@@ -126,12 +126,11 @@ if __name__ == "__main__":
         anneal_interval = train_config["anneal_interval"]
         lr_scheduler = dg.ExponentialDecay(
             learning_rate, anneal_interval, anneal_rate, staircase=True)
-        optim = fluid.optimizer.Adam(
-            lr_scheduler, parameter_list=model.parameters())
         gradiant_max_norm = train_config["gradient_max_norm"]
-        clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
-            gradiant_max_norm)
+        optim = fluid.optimizer.Adam(
+            lr_scheduler,
+            parameter_list=model.parameters(),
+            grad_clip=fluid.clip.ClipByGlobalNorm(gradiant_max_norm))

         train_loader = fluid.io.DataLoader.from_generator(
             capacity=10, return_list=True)

@@ -149,7 +148,7 @@ if __name__ == "__main__":
         log_dir = os.path.join(args.output, "log")
         writer = SummaryWriter(log_dir)

-        # load parameters and optimizer, and opdate iterations done so far
+        # load parameters and optimizer, and update iterations done so far
         if args.checkpoint is not None:
             iteration = io.load_parameters(
                 model, optim, checkpoint_path=args.checkpoint)

@@ -181,7 +180,7 @@ if __name__ == "__main__":
             writer.add_scalar("learning_rate",
                               optim._learning_rate.step().numpy()[0],
                               global_step)
-            optim.minimize(loss_var, grad_clip=clipper)
+            optim.minimize(loss_var)
             optim.clear_gradients()
             print("global_step: {}\tloss: {:<8.6f}".format(
                 global_step, loss_np[0]))
parakeet/models/clarinet/utils.py

@@ -29,22 +29,10 @@ def conv2d(input,
                data_format="NCHW"):
     padding = tuple(pad for pad_dim in padding for pad in pad_dim)

-    inputs = {
-        'Input': [input],
-        'Filter': [weight],
-    }
-    attrs = {
-        'strides': stride,
-        'paddings': padding,
-        'dilations': dilation,
-        'groups': groups,
-        'use_cudnn': use_cudnn,
-        'use_mkldnn': False,
-        'fuse_relu_before_depthwise_conv': False,
-        "padding_algorithm": "EXPLICIT",
-        "data_format": data_format,
-    }
+    attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
+             'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
+             'fuse_relu_before_depthwise_conv', False, "padding_algorithm",
+             "EXPLICIT", "data_format", data_format)

-    outputs = ops.conv2d(inputs, attrs)
-    out = outputs["Output"][0]
-    return out
\ No newline at end of file
+    out = ops.conv2d(input, weight, *attrs)
+    return out
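The rewrite drops the dict-of-tensor-lists calling convention in favor of a flat, interleaved name/value tuple splatted after the positional inputs. How such a flat tuple maps back onto keyword attributes can be shown in pure Python (this assumes nothing about ops.conv2d itself):

def parse_flat_attrs(*attrs):
    # Interleaved ('name', value, 'name', value, ...) pairs -> dict.
    return dict(zip(attrs[::2], attrs[1::2]))

flat = ('strides', [1, 1], 'use_cudnn', True, 'data_format', "NCHW")
print(parse_flat_attrs(*flat))
# {'strides': [1, 1], 'use_cudnn': True, 'data_format': 'NCHW'}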
parakeet/models/deepvoice3/loss.py

@@ -262,7 +262,7 @@ class TTSLoss(object):
         if compute_lin_loss:
             lin_hyp = lin_hyp[:, :-self.time_shift, :]
             lin_ref = lin_ref[:, self.time_shift:, :]
-            lin_mask = lin_mask[:, self.time_shift:, :]
+            lin_mask = lin_mask[:, self.time_shift:]
             lin_l1_loss = self.l1_loss(
                 lin_hyp, lin_ref, lin_mask, priority_bin=self.priority_bin)
             lin_bce_loss = self.binary_divergence(lin_hyp, lin_ref, lin_mask)

@@ -273,7 +273,7 @@ class TTSLoss(object):
         if compute_mel_loss:
             mel_hyp = mel_hyp[:, :-self.time_shift, :]
             mel_ref = mel_ref[:, self.time_shift:, :]
-            mel_mask = mel_mask[:, self.time_shift:, :]
+            mel_mask = mel_mask[:, self.time_shift:]
             mel_l1_loss = self.l1_loss(mel_hyp, mel_ref, mel_mask)
             mel_bce_loss = self.binary_divergence(mel_hyp, mel_ref, mel_mask)
             # print("=====>", mel_l1_loss.numpy()[0], mel_bce_loss.numpy()[0])
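Both hunks make the same fix: the trailing ", :" is dropped from the mask slice, which implies the masks are 2-D [batch, time] tensors that were being over-indexed. A numpy analogy, with shapes assumed from context rather than taken from the repo:

import numpy as np

mask = np.ones((2, 10), dtype="float32")  # [batch, time]
ok = mask[:, 3:]    # valid: two indices on a 2-D array
# mask[:, 3:, :]    # IndexError: too many indices for array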
parakeet/models/deepvoice3/position_embedding.py

@@ -31,8 +31,10 @@ def compute_position_embedding(radians, speaker_position_rate):
     """
     _, embed_dim = radians.shape
     batch_size = speaker_position_rate.shape[0]
-    speaker_position_rate = F.unsqueeze(speaker_position_rate, [1, 2])
-    scaled_radians = speaker_position_rate * radians
+    scaled_radians = F.elementwise_mul(
+        F.expand(F.unsqueeze(radians, [0]), [batch_size, 1, 1]),
+        speaker_position_rate,
+        axis=0)

     odd_mask = (np.arange(embed_dim) % 2).astype(np.float32)
     odd_mask = dg.to_variable(odd_mask)
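The new code replaces the implicit multiplication of an unsqueezed rate against radians with an explicit expand-then-elementwise_mul along the batch axis. A numpy analogy of the new computation, with shapes inferred from context (radians is [time, embed_dim] and speaker_position_rate is [batch]):

import numpy as np

batch_size, time_steps, embed_dim = 2, 5, 4
radians = np.random.rand(time_steps, embed_dim).astype("float32")
rate = np.random.rand(batch_size).astype("float32")

# F.expand(F.unsqueeze(radians, [0]), [batch_size, 1, 1]) -> [batch, time, dim]
expanded = np.tile(radians[None, :, :], (batch_size, 1, 1))
# elementwise_mul(..., axis=0) broadcasts `rate` along the batch axis
scaled = expanded * rate[:, None, None]
print(scaled.shape)  # (2, 5, 4)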
parakeet/models/wavenet/wavenet.py

@@ -111,7 +111,7 @@ class ResidualBlock(dg.Layer):
         h = h[:, :, :time_steps]

         # condition
-        if condition:
+        if condition is not None:
             h += self.condition_proj(condition)

         # gated tanh

@@ -398,7 +398,8 @@ class WaveNet(dg.Layer):
         x_std = inv_std * (t - mu)
         exponent = F.exp(-0.5 * x_std * x_std)
-        pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_std * exponent
+        pdf_x = 1.0 / math.sqrt(2.0 * math.pi) * inv_std * exponent
+        pdf_x = p_mixture * pdf_x
         # pdf_x: [bs, len]
         pdf_x = F.reduce_sum(pdf_x, dim=-1)
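Two small fixes here: `if condition:` asks for the truth value of a whole tensor, so the test becomes an explicit `is not None`; and, on a plausible reading of the second hunk, np.sqrt returns a numpy.float64 scalar that can force float64 promotion against float32 Variables, while math.sqrt returns a plain Python float. The scalar types, shown directly:

import math
import numpy as np

print(type(np.sqrt(2.0 * np.pi)))       # <class 'numpy.float64'>
print(type(math.sqrt(2.0 * math.pi)))   # <class 'float'>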
parakeet/modules/weight_norm.py

@@ -84,13 +84,15 @@ class WeightNormWrapper(dg.Layer):
             w_v,
             self.create_parameter(
                 shape=original_weight.shape, dtype=original_weight.dtype))
-        F.assign(original_weight, getattr(self, w_v))
+        with dg.no_grad():
+            F.assign(original_weight, getattr(self, w_v))

         delattr(layer, param_name)
         temp = norm_except(getattr(self, w_v), self.dim, self.power)
         self.add_parameter(
             w_g, self.create_parameter(
                 shape=temp.shape, dtype=temp.dtype))
-        F.assign(temp, getattr(self, w_g))
+        with dg.no_grad():
+            F.assign(temp, getattr(self, w_g))

         # also set this when setting up
         setattr(self.layer, self.param_name,
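Wrapping the two F.assign copies in dg.no_grad() keeps the one-time parameter initialization out of gradient bookkeeping. A minimal sketch of the pattern, assuming paddlepaddle==1.8.0, with a Linear layer standing in for the wrapped layer:

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

with dg.guard():
    layer = dg.Linear(3, 3)  # stand-in for the wrapped layer
    src = dg.to_variable(np.ones((3, 3), dtype="float32"))
    # Copy values into the parameter without recording the copy for autograd.
    with dg.no_grad():
        fluid.layers.assign(src, layer.weight)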