Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Parakeet
提交
f9d97852
P
Parakeet
项目概览
PaddlePaddle
/
Parakeet
通知
14
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Parakeet
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f9d97852
编写于
2月 23, 2020
作者:
L
liuyibing01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update waveflow to 1.7 api and verified training
上级
43814acb
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
427 addition
and
291 deletion
+427
-291
examples/deepvoice3/train.py
examples/deepvoice3/train.py
+74
-74
examples/waveflow/README.md
examples/waveflow/README.md
+5
-5
examples/waveflow/benchmark.py
examples/waveflow/benchmark.py
+21
-9
examples/waveflow/configs/waveflow_ljspeech.yaml
examples/waveflow/configs/waveflow_ljspeech.yaml
+0
-0
examples/waveflow/synthesis.py
examples/waveflow/synthesis.py
+30
-12
examples/waveflow/train.py
examples/waveflow/train.py
+34
-18
examples/waveflow/utils.py
examples/waveflow/utils.py
+78
-36
parakeet/datasets/ljspeech.py
parakeet/datasets/ljspeech.py
+26
-19
parakeet/models/waveflow/__init__.py
parakeet/models/waveflow/__init__.py
+1
-0
parakeet/models/waveflow/data.py
parakeet/models/waveflow/data.py
+21
-21
parakeet/models/waveflow/waveflow.py
parakeet/models/waveflow/waveflow.py
+39
-26
parakeet/models/waveflow/waveflow_modules.py
parakeet/models/waveflow/waveflow_modules.py
+77
-58
parakeet/modules/weight_norm.py
parakeet/modules/weight_norm.py
+21
-13
未找到文件。
examples/deepvoice3/train.py
浏览文件 @
f9d97852
...
...
@@ -28,22 +28,21 @@ if __name__ == "__main__":
parser
=
argparse
.
ArgumentParser
(
description
=
"Train a deepvoice 3 model with LJSpeech dataset."
)
parser
.
add_argument
(
"-c"
,
"--config"
,
type
=
str
,
help
=
"experimrnt config"
)
parser
.
add_argument
(
"-s"
,
parser
.
add_argument
(
"-s"
,
"--data"
,
type
=
str
,
default
=
"/workspace/datasets/LJSpeech-1.1/"
,
help
=
"The path of the LJSpeech dataset."
)
parser
.
add_argument
(
"-r"
,
"--resume"
,
type
=
str
,
help
=
"checkpoint to load"
)
parser
.
add_argument
(
"-o"
,
parser
.
add_argument
(
"-o"
,
"--output"
,
type
=
str
,
default
=
"result"
,
help
=
"The directory to save result."
)
parser
.
add_argument
(
"-g"
,
"--device"
,
type
=
int
,
default
=-
1
,
help
=
"device to use"
)
parser
.
add_argument
(
"-g"
,
"--device"
,
type
=
int
,
default
=-
1
,
help
=
"device to use"
)
args
,
_
=
parser
.
parse_known_args
()
with
open
(
args
.
config
,
'rt'
)
as
f
:
config
=
ruamel
.
yaml
.
safe_load
(
f
)
...
...
@@ -84,18 +83,16 @@ if __name__ == "__main__":
train_config
=
config
[
"train"
]
batch_size
=
train_config
[
"batch_size"
]
text_lengths
=
[
len
(
example
[
2
])
for
example
in
meta
]
sampler
=
PartialyRandomizedSimilarTimeLengthSampler
(
text_lengths
,
batch_size
)
sampler
=
PartialyRandomizedSimilarTimeLengthSampler
(
text_lengths
,
batch_size
)
# some hyperparameters affect how we process data, so create a data collector!
model_config
=
config
[
"model"
]
downsample_factor
=
model_config
[
"downsample_factor"
]
r
=
model_config
[
"outputs_per_step"
]
collector
=
DataCollector
(
downsample_factor
=
downsample_factor
,
r
=
r
)
ljspeech_loader
=
DataCargo
(
ljspeech
,
batch_fn
=
collector
,
batch_size
=
batch_size
,
sampler
=
sampler
)
ljspeech_loader
=
DataCargo
(
ljspeech
,
batch_fn
=
collector
,
batch_size
=
batch_size
,
sampler
=
sampler
)
# =========================model=========================
if
args
.
device
==
-
1
:
...
...
@@ -131,15 +128,14 @@ if __name__ == "__main__":
window_ahead
=
model_config
[
"window_ahead"
]
key_projection
=
model_config
[
"key_projection"
]
value_projection
=
model_config
[
"value_projection"
]
dv3
=
make_model
(
n_speakers
,
speaker_dim
,
speaker_embed_std
,
embed_dim
,
padding_idx
,
embedding_std
,
max_positions
,
n_vocab
,
freeze_embedding
,
filter_size
,
encoder_channels
,
n_mels
,
decoder_channels
,
r
,
dv3
=
make_model
(
n_speakers
,
speaker_dim
,
speaker_embed_std
,
embed_dim
,
padding_idx
,
embedding_std
,
max_positions
,
n_vocab
,
freeze_embedding
,
filter_size
,
encoder_channels
,
n_mels
,
decoder_channels
,
r
,
trainable_positional_encodings
,
use_memory_mask
,
query_position_rate
,
key_position_rate
,
window_backward
,
window_ahead
,
key_projection
,
value_projection
,
downsample_factor
,
linear_dim
,
use_decoder_states
,
converter_channels
,
dropout
)
query_position_rate
,
key_position_rate
,
window_backward
,
window_ahead
,
key_projection
,
value_projection
,
downsample_factor
,
linear_dim
,
use_decoder_states
,
converter_channels
,
dropout
)
# =========================loss=========================
loss_config
=
config
[
"loss"
]
...
...
@@ -149,7 +145,8 @@ if __name__ == "__main__":
priority_freq_weight
=
loss_config
[
"priority_freq_weight"
]
binary_divergence_weight
=
loss_config
[
"binary_divergence_weight"
]
guided_attention_sigma
=
loss_config
[
"guided_attention_sigma"
]
criterion
=
TTSLoss
(
masked_weight
=
masked_weight
,
criterion
=
TTSLoss
(
masked_weight
=
masked_weight
,
priority_bin
=
priority_bin
,
priority_weight
=
priority_freq_weight
,
binary_divergence_weight
=
binary_divergence_weight
,
...
...
@@ -169,7 +166,8 @@ if __name__ == "__main__":
beta1
=
optim_config
[
"beta1"
]
beta2
=
optim_config
[
"beta2"
]
epsilon
=
optim_config
[
"epsilon"
]
optim
=
fluid
.
optimizer
.
Adam
(
lr_scheduler
,
optim
=
fluid
.
optimizer
.
Adam
(
lr_scheduler
,
beta1
,
beta2
,
epsilon
=
epsilon
,
...
...
@@ -183,8 +181,8 @@ if __name__ == "__main__":
# =========================link(dataloader, paddle)=========================
# CAUTION: it does not return a DataLoader
loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
10
,
return_list
=
True
)
loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
10
,
return_list
=
True
)
loader
.
set_batch_generator
(
ljspeech_loader
,
places
=
place
)
# tensorboard & checkpoint preparation
...
...
@@ -247,7 +245,8 @@ if __name__ == "__main__":
# TODO: clean code
# train state saving, the first sentence in the batch
if
global_step
%
snap_interval
==
0
:
save_state
(
state_dir
,
save_state
(
state_dir
,
writer
,
global_step
,
mel_input
=
downsampled_mel_specs
,
...
...
@@ -275,16 +274,16 @@ if __name__ == "__main__":
"Some have accepted this as a miracle without any physical explanation."
,
]
for
idx
,
sent
in
enumerate
(
sentences
):
wav
,
attn
=
eval_model
(
dv3
,
sent
,
replace_pronounciation_prob
,
min_level_db
,
ref_level_db
,
power
,
n_iter
,
win_length
,
hop_length
,
preemphasis
)
wav
,
attn
=
eval_model
(
dv3
,
sent
,
replace_pronounciation_prob
,
min_level_db
,
ref_level_db
,
power
,
n_iter
,
win_length
,
hop_length
,
preemphasis
)
wav_path
=
os
.
path
.
join
(
state_dir
,
"waveform"
,
"eval_sample_{:09d}.wav"
.
format
(
global_step
))
sf
.
write
(
wav_path
,
wav
,
sample_rate
)
writer
.
add_audio
(
"eval_sample_{}"
.
format
(
idx
),
writer
.
add_audio
(
"eval_sample_{}"
.
format
(
idx
),
wav
,
global_step
,
sample_rate
=
sample_rate
)
...
...
@@ -292,7 +291,8 @@ if __name__ == "__main__":
state_dir
,
"alignments"
,
"eval_sample_attn_{:09d}.png"
.
format
(
global_step
))
plot_alignment
(
attn
,
attn_path
)
writer
.
add_image
(
"eval_sample_attn{}"
.
format
(
idx
),
writer
.
add_image
(
"eval_sample_attn{}"
.
format
(
idx
),
cm
.
viridis
(
attn
),
global_step
,
dataformats
=
"HWC"
)
...
...
parakeet/model
s/waveflow/README.md
→
example
s/waveflow/README.md
浏览文件 @
f9d97852
parakeet/model
s/waveflow/benchmark.py
→
example
s/waveflow/benchmark.py
浏览文件 @
f9d97852
...
...
@@ -12,20 +12,32 @@ from waveflow import WaveFlow
def
add_options_to_parser
(
parser
):
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'waveflow'
,
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'waveflow'
,
help
=
"general name of the model"
)
parser
.
add_argument
(
'--name'
,
type
=
str
,
help
=
"specific name of the training model"
)
parser
.
add_argument
(
'--root'
,
type
=
str
,
help
=
"root path of the LJSpeech dataset"
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
bool
,
default
=
True
,
parser
.
add_argument
(
'--name'
,
type
=
str
,
help
=
"specific name of the training model"
)
parser
.
add_argument
(
'--root'
,
type
=
str
,
help
=
"root path of the LJSpeech dataset"
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
bool
,
default
=
True
,
help
=
"option to use gpu training"
)
parser
.
add_argument
(
'--iteration'
,
type
=
int
,
default
=
None
,
parser
.
add_argument
(
'--iteration'
,
type
=
int
,
default
=
None
,
help
=
(
"which iteration of checkpoint to load, "
"default to load the latest checkpoint"
))
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
default
=
None
,
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
default
=
None
,
help
=
"path of the checkpoint to load"
)
...
...
parakeet/model
s/waveflow/configs/waveflow_ljspeech.yaml
→
example
s/waveflow/configs/waveflow_ljspeech.yaml
浏览文件 @
f9d97852
文件已移动
parakeet/model
s/waveflow/synthesis.py
→
example
s/waveflow/synthesis.py
浏览文件 @
f9d97852
...
...
@@ -12,25 +12,43 @@ from waveflow import WaveFlow
def
add_options_to_parser
(
parser
):
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'waveflow'
,
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'waveflow'
,
help
=
"general name of the model"
)
parser
.
add_argument
(
'--name'
,
type
=
str
,
help
=
"specific name of the training model"
)
parser
.
add_argument
(
'--root'
,
type
=
str
,
help
=
"root path of the LJSpeech dataset"
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
bool
,
default
=
True
,
parser
.
add_argument
(
'--name'
,
type
=
str
,
help
=
"specific name of the training model"
)
parser
.
add_argument
(
'--root'
,
type
=
str
,
help
=
"root path of the LJSpeech dataset"
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
bool
,
default
=
True
,
help
=
"option to use gpu training"
)
parser
.
add_argument
(
'--iteration'
,
type
=
int
,
default
=
None
,
parser
.
add_argument
(
'--iteration'
,
type
=
int
,
default
=
None
,
help
=
(
"which iteration of checkpoint to load, "
"default to load the latest checkpoint"
))
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
default
=
None
,
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
default
=
None
,
help
=
"path of the checkpoint to load"
)
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
"./syn_audios"
,
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
"./syn_audios"
,
help
=
"path to write synthesized audio files"
)
parser
.
add_argument
(
'--sample'
,
type
=
int
,
default
=
None
,
parser
.
add_argument
(
'--sample'
,
type
=
int
,
default
=
None
,
help
=
"which of the valid samples to synthesize audio"
)
...
...
parakeet/model
s/waveflow/train.py
→
example
s/waveflow/train.py
浏览文件 @
f9d97852
...
...
@@ -4,34 +4,48 @@ import subprocess
import
time
from
pprint
import
pprint
import
json
argparse
import
argparse
import
numpy
as
np
import
paddle.fluid.dygraph
as
dg
from
paddle
import
fluid
from
tensorboardX
import
SummaryWriter
import
slurm
import
utils
from
waveflow
import
WaveFlow
from
parakeet.models.
waveflow
import
WaveFlow
def
add_options_to_parser
(
parser
):
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'waveflow'
,
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'waveflow'
,
help
=
"general name of the model"
)
parser
.
add_argument
(
'--name'
,
type
=
str
,
help
=
"specific name of the training model"
)
parser
.
add_argument
(
'--root'
,
type
=
str
,
help
=
"root path of the LJSpeech dataset"
)
parser
.
add_argument
(
'--parallel'
,
type
=
bool
,
default
=
True
,
parser
.
add_argument
(
'--name'
,
type
=
str
,
help
=
"specific name of the training model"
)
parser
.
add_argument
(
'--root'
,
type
=
str
,
help
=
"root path of the LJSpeech dataset"
)
parser
.
add_argument
(
'--parallel'
,
type
=
utils
.
str2bool
,
default
=
True
,
help
=
"option to use data parallel training"
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
bool
,
default
=
True
,
parser
.
add_argument
(
'--use_gpu'
,
type
=
utils
.
str2bool
,
default
=
True
,
help
=
"option to use gpu training"
)
parser
.
add_argument
(
'--iteration'
,
type
=
int
,
default
=
None
,
parser
.
add_argument
(
'--iteration'
,
type
=
int
,
default
=
None
,
help
=
(
"which iteration of checkpoint to load, "
"default to load the latest checkpoint"
))
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
default
=
None
,
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
default
=
None
,
help
=
"path of the checkpoint to load"
)
...
...
@@ -45,12 +59,13 @@ def train(config):
if
rank
==
0
:
# Print the whole config setting.
pprint
(
jsonargparse
.
namespace_to_dict
(
config
))
pprint
(
vars
(
config
))
# Make checkpoint directory.
run_dir
=
os
.
path
.
join
(
"runs"
,
config
.
model
,
config
.
name
)
checkpoint_dir
=
os
.
path
.
join
(
run_dir
,
"checkpoint"
)
os
.
makedirs
(
checkpoint_dir
,
exist_ok
=
True
)
if
not
os
.
path
.
exists
(
checkpoint_dir
):
os
.
makedirs
(
checkpoint_dir
)
# Create tensorboard logger.
tb
=
SummaryWriter
(
os
.
path
.
join
(
run_dir
,
"logs"
))
\
...
...
@@ -102,8 +117,8 @@ def train(config):
if
__name__
==
"__main__"
:
# Create parser.
parser
=
jsonargparse
.
ArgumentParser
(
description
=
"Train WaveFlow model"
,
formatter_class
=
'default_argparse'
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Train WaveFlow model"
)
#
formatter_class='default_argparse')
add_options_to_parser
(
parser
)
utils
.
add_config_options_to_parser
(
parser
)
...
...
@@ -111,4 +126,5 @@ if __name__ == "__main__":
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
config
=
parser
.
parse_args
()
config
=
utils
.
add_yaml_config
(
config
)
train
(
config
)
parakeet/model
s/waveflow/utils.py
→
example
s/waveflow/utils.py
浏览文件 @
f9d97852
...
...
@@ -2,59 +2,97 @@ import itertools
import
os
import
time
import
jsonargparse
import
argparse
import
ruamel.yaml
import
numpy
as
np
import
paddle.fluid.dygraph
as
dg
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
def
add_config_options_to_parser
(
parser
):
parser
.
add_argument
(
'--valid_size'
,
type
=
int
,
help
=
"size of the valid dataset"
)
parser
.
add_argument
(
'--segment_length'
,
type
=
int
,
parser
.
add_argument
(
'--valid_size'
,
type
=
int
,
help
=
"size of the valid dataset"
)
parser
.
add_argument
(
'--segment_length'
,
type
=
int
,
help
=
"the length of audio clip for training"
)
parser
.
add_argument
(
'--sample_rate'
,
type
=
int
,
help
=
"sampling rate of audio data file"
)
parser
.
add_argument
(
'--fft_window_shift'
,
type
=
int
,
parser
.
add_argument
(
'--sample_rate'
,
type
=
int
,
help
=
"sampling rate of audio data file"
)
parser
.
add_argument
(
'--fft_window_shift'
,
type
=
int
,
help
=
"the shift of fft window for each frame"
)
parser
.
add_argument
(
'--fft_window_size'
,
type
=
int
,
parser
.
add_argument
(
'--fft_window_size'
,
type
=
int
,
help
=
"the size of fft window for each frame"
)
parser
.
add_argument
(
'--fft_size'
,
type
=
int
,
help
=
"the size of fft filter on each frame"
)
parser
.
add_argument
(
'--mel_bands'
,
type
=
int
,
parser
.
add_argument
(
'--fft_size'
,
type
=
int
,
help
=
"the size of fft filter on each frame"
)
parser
.
add_argument
(
'--mel_bands'
,
type
=
int
,
help
=
"the number of mel bands when calculating mel spectrograms"
)
parser
.
add_argument
(
'--mel_fmin'
,
type
=
float
,
parser
.
add_argument
(
'--mel_fmin'
,
type
=
float
,
help
=
"lowest frequency in calculating mel spectrograms"
)
parser
.
add_argument
(
'--mel_fmax'
,
type
=
float
,
parser
.
add_argument
(
'--mel_fmax'
,
type
=
float
,
help
=
"highest frequency in calculating mel spectrograms"
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
help
=
"seed of random initialization for the model"
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
help
=
"seed of random initialization for the model"
)
parser
.
add_argument
(
'--learning_rate'
,
type
=
float
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
help
=
"batch size for training"
)
parser
.
add_argument
(
'--test_every'
,
type
=
int
,
help
=
"test interval during training"
)
parser
.
add_argument
(
'--save_every'
,
type
=
int
,
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
help
=
"batch size for training"
)
parser
.
add_argument
(
'--test_every'
,
type
=
int
,
help
=
"test interval during training"
)
parser
.
add_argument
(
'--save_every'
,
type
=
int
,
help
=
"checkpointing interval during training"
)
parser
.
add_argument
(
'--max_iterations'
,
type
=
int
,
help
=
"maximum training iterations"
)
parser
.
add_argument
(
'--max_iterations'
,
type
=
int
,
help
=
"maximum training iterations"
)
parser
.
add_argument
(
'--sigma'
,
type
=
float
,
parser
.
add_argument
(
'--sigma'
,
type
=
float
,
help
=
"standard deviation of the latent Gaussian variable"
)
parser
.
add_argument
(
'--n_flows'
,
type
=
int
,
help
=
"number of flows"
)
parser
.
add_argument
(
'--n_group'
,
type
=
int
,
parser
.
add_argument
(
'--n_flows'
,
type
=
int
,
help
=
"number of flows"
)
parser
.
add_argument
(
'--n_group'
,
type
=
int
,
help
=
"number of adjacent audio samples to squeeze into one column"
)
parser
.
add_argument
(
'--n_layers'
,
type
=
int
,
parser
.
add_argument
(
'--n_layers'
,
type
=
int
,
help
=
"number of conv2d layer in one wavenet-like flow architecture"
)
parser
.
add_argument
(
'--n_channels'
,
type
=
int
,
help
=
"number of residual channels in flow"
)
parser
.
add_argument
(
'--kernel_h'
,
type
=
int
,
parser
.
add_argument
(
'--n_channels'
,
type
=
int
,
help
=
"number of residual channels in flow"
)
parser
.
add_argument
(
'--kernel_h'
,
type
=
int
,
help
=
"height of the kernel in the conv2d layer"
)
parser
.
add_argument
(
'--kernel_w'
,
type
=
int
,
help
=
"width of the kernel in the conv2d layer"
)
parser
.
add_argument
(
'--kernel_w'
,
type
=
int
,
help
=
"width of the kernel in the conv2d layer"
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
"Path to the config file."
)
parser
.
add_argument
(
'--config'
,
action
=
jsonargparse
.
ActionConfigFile
)
def
add_yaml_config
(
config
):
print
(
config
)
with
open
(
config
.
config
,
'rt'
)
as
f
:
yaml_cfg
=
ruamel
.
yaml
.
safe_load
(
f
)
cfg_vars
=
vars
(
config
)
for
k
,
v
in
yaml_cfg
.
items
():
if
k
in
cfg_vars
and
cfg_vars
[
k
]
is
not
None
:
continue
cfg_vars
[
k
]
=
v
return
config
def
load_latest_checkpoint
(
checkpoint_dir
,
rank
=
0
):
...
...
@@ -84,8 +122,12 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
handle
.
write
(
"model_checkpoint_path: step-{}"
.
format
(
iteration
))
def
load_parameters
(
checkpoint_dir
,
rank
,
model
,
optimizer
=
None
,
iteration
=
None
,
file_path
=
None
):
def
load_parameters
(
checkpoint_dir
,
rank
,
model
,
optimizer
=
None
,
iteration
=
None
,
file_path
=
None
):
if
file_path
is
None
:
if
iteration
is
None
:
iteration
=
load_latest_checkpoint
(
checkpoint_dir
,
rank
)
...
...
parakeet/datasets/ljspeech.py
浏览文件 @
f9d97852
...
...
@@ -5,21 +5,26 @@ import librosa
from
..
import
g2p
from
..data.sampler
import
SequentialSampler
,
RandomSampler
,
BatchSampler
from
..data.dataset
import
Dataset
from
..data.dataset
import
Dataset
Mixin
from
..data.datacargo
import
DataCargo
from
..data.batch
import
TextIDBatcher
,
SpecBatcher
class
LJSpeech
(
Dataset
):
class
LJSpeech
(
Dataset
Mixin
):
def
__init__
(
self
,
root
):
super
(
LJSpeech
,
self
).
__init__
()
assert
isinstance
(
root
,
(
str
,
Path
)),
"root should be a string or Path object"
assert
isinstance
(
root
,
(
str
,
Path
)),
"root should be a string or Path object"
self
.
root
=
root
if
isinstance
(
root
,
Path
)
else
Path
(
root
)
self
.
metadata
=
self
.
_prepare_metadata
()
def
_prepare_metadata
(
self
):
csv_path
=
self
.
root
.
joinpath
(
"metadata.csv"
)
metadata
=
pd
.
read_csv
(
csv_path
,
sep
=
"|"
,
header
=
None
,
quoting
=
3
,
metadata
=
pd
.
read_csv
(
csv_path
,
sep
=
"|"
,
header
=
None
,
quoting
=
3
,
names
=
[
"fname"
,
"raw_text"
,
"normalized_text"
])
return
metadata
...
...
@@ -35,7 +40,9 @@ class LJSpeech(Dataset):
wav_path
=
self
.
root
.
joinpath
(
"wavs"
,
fname
+
".wav"
)
# load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize
wav
,
sample_rate
=
librosa
.
load
(
wav_path
,
sr
=
None
)
# we would rather use functor to hold its parameters
wav
,
sample_rate
=
librosa
.
load
(
wav_path
,
sr
=
None
)
# we would rather use functor to hold its parameters
trimed
,
_
=
librosa
.
effects
.
trim
(
wav
)
preemphasized
=
librosa
.
effects
.
preemphasis
(
trimed
)
D
=
librosa
.
stft
(
preemphasized
)
...
...
@@ -50,8 +57,10 @@ class LJSpeech(Dataset):
mel
=
np
.
clip
((
mel
-
ref_db
+
max_db
)
/
max_db
,
1e-8
,
1
)
mel
=
np
.
clip
((
mag
-
ref_db
+
max_db
)
/
max_db
,
1e-8
,
1
)
phonemes
=
np
.
array
(
g2p
.
en
.
text_to_sequence
(
normalized_text
),
dtype
=
np
.
int64
)
return
(
mag
,
mel
,
phonemes
)
# maybe we need to implement it as a map in the future
phonemes
=
np
.
array
(
g2p
.
en
.
text_to_sequence
(
normalized_text
),
dtype
=
np
.
int64
)
return
(
mag
,
mel
,
phonemes
)
# maybe we need to implement it as a map in the future
def
_batch_examples
(
self
,
minibatch
):
mag_batch
=
[]
...
...
@@ -78,5 +87,3 @@ class LJSpeech(Dataset):
def
__len__
(
self
):
return
len
(
self
.
metadata
)
parakeet/models/waveflow/__init__.py
0 → 100644
浏览文件 @
f9d97852
from
parakeet.models.waveflow.waveflow
import
WaveFlow
parakeet/models/waveflow/data.py
浏览文件 @
f9d97852
...
...
@@ -5,10 +5,9 @@ import numpy as np
from
paddle
import
fluid
from
parakeet.datasets
import
ljspeech
from
parakeet.data
import
dataset
from
parakeet.data.batch
import
SpecBatcher
,
WavBatcher
from
parakeet.data.datacargo
import
DataCargo
from
parakeet.data.sampler
import
DistributedSampler
,
BatchSampler
from
parakeet.data
import
SpecBatcher
,
WavBatcher
from
parakeet.data
import
DataCargo
,
DatasetMixin
from
parakeet.data
import
DistributedSampler
,
BatchSampler
from
scipy.io.wavfile
import
read
...
...
@@ -27,7 +26,7 @@ class Dataset(ljspeech.LJSpeech):
return
audio
class
Subset
(
dataset
.
Dataset
):
class
Subset
(
DatasetMixin
):
def
__init__
(
self
,
dataset
,
indices
,
valid
):
self
.
dataset
=
dataset
self
.
indices
=
indices
...
...
@@ -36,14 +35,14 @@ class Subset(dataset.Dataset):
def
get_mel
(
self
,
audio
):
spectrogram
=
librosa
.
core
.
stft
(
audio
,
n_fft
=
self
.
config
.
fft_size
,
audio
,
n_fft
=
self
.
config
.
fft_size
,
hop_length
=
self
.
config
.
fft_window_shift
,
win_length
=
self
.
config
.
fft_window_size
)
spectrogram_magnitude
=
np
.
abs
(
spectrogram
)
# mel_filter_bank shape: [n_mels, 1 + n_fft/2]
mel_filter_bank
=
librosa
.
filters
.
mel
(
sr
=
self
.
config
.
sample_rate
,
mel_filter_bank
=
librosa
.
filters
.
mel
(
sr
=
self
.
config
.
sample_rate
,
n_fft
=
self
.
config
.
fft_size
,
n_mels
=
self
.
config
.
mel_bands
,
fmin
=
self
.
config
.
mel_fmin
,
...
...
@@ -70,10 +69,11 @@ class Subset(dataset.Dataset):
if
audio
.
shape
[
0
]
>=
segment_length
:
max_audio_start
=
audio
.
shape
[
0
]
-
segment_length
audio_start
=
random
.
randint
(
0
,
max_audio_start
)
audio
=
audio
[
audio_start
:
(
audio_start
+
segment_length
)]
audio
=
audio
[
audio_start
:
(
audio_start
+
segment_length
)]
else
:
audio
=
np
.
pad
(
audio
,
(
0
,
segment_length
-
audio
.
shape
[
0
]),
mode
=
'constant'
,
constant_values
=
0
)
mode
=
'constant'
,
constant_values
=
0
)
# Normalize audio to the [-1, 1] range.
audio
=
audio
.
astype
(
np
.
float32
)
/
32768.0
...
...
@@ -112,8 +112,8 @@ class LJSpeech:
sampler
=
DistributedSampler
(
len
(
trainset
),
nranks
,
rank
)
total_bs
=
config
.
batch_size
assert
total_bs
%
nranks
==
0
train_sampler
=
BatchSampler
(
sampler
,
total_bs
//
nranks
,
drop_last
=
True
)
train_sampler
=
BatchSampler
(
sampler
,
total_bs
//
nranks
,
drop_last
=
True
)
trainloader
=
DataCargo
(
trainset
,
batch_sampler
=
train_sampler
)
trainreader
=
fluid
.
io
.
PyReader
(
capacity
=
50
,
return_list
=
True
)
...
...
parakeet/models/waveflow/waveflow.py
浏览文件 @
f9d97852
...
...
@@ -8,13 +8,18 @@ from paddle import fluid
from
scipy.io.wavfile
import
write
import
utils
from
data
import
LJSpeech
from
waveflow_modules
import
WaveFlowLoss
,
WaveFlowModule
from
.
data
import
LJSpeech
from
.
waveflow_modules
import
WaveFlowLoss
,
WaveFlowModule
class
WaveFlow
():
def
__init__
(
self
,
config
,
checkpoint_dir
,
parallel
=
False
,
rank
=
0
,
nranks
=
1
,
tb_logger
=
None
):
def
__init__
(
self
,
config
,
checkpoint_dir
,
parallel
=
False
,
rank
=
0
,
nranks
=
1
,
tb_logger
=
None
):
self
.
config
=
config
self
.
checkpoint_dir
=
checkpoint_dir
self
.
parallel
=
parallel
...
...
@@ -28,7 +33,7 @@ class WaveFlow():
self
.
trainloader
=
dataset
.
trainloader
self
.
validloader
=
dataset
.
validloader
waveflow
=
WaveFlowModule
(
"waveflow"
,
config
)
waveflow
=
WaveFlowModule
(
config
)
# Dry run once to create and initalize all necessary parameters.
audio
=
dg
.
to_variable
(
np
.
random
.
randn
(
1
,
16000
).
astype
(
np
.
float32
))
...
...
@@ -38,11 +43,15 @@ class WaveFlow():
if
training
:
optimizer
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
config
.
learning_rate
)
learning_rate
=
config
.
learning_rate
,
parameter_list
=
waveflow
.
parameters
())
# Load parameters.
utils
.
load_parameters
(
self
.
checkpoint_dir
,
self
.
rank
,
waveflow
,
optimizer
,
utils
.
load_parameters
(
self
.
checkpoint_dir
,
self
.
rank
,
waveflow
,
optimizer
,
iteration
=
config
.
iteration
,
file_path
=
config
.
checkpoint
)
print
(
"Rank {}: checkpoint loaded."
.
format
(
self
.
rank
))
...
...
@@ -58,7 +67,10 @@ class WaveFlow():
else
:
# Load parameters.
utils
.
load_parameters
(
self
.
checkpoint_dir
,
self
.
rank
,
waveflow
,
utils
.
load_parameters
(
self
.
checkpoint_dir
,
self
.
rank
,
waveflow
,
iteration
=
config
.
iteration
,
file_path
=
config
.
checkpoint
)
print
(
"Rank {}: checkpoint loaded."
.
format
(
self
.
rank
))
...
...
@@ -83,7 +95,8 @@ class WaveFlow():
else
:
loss
.
backward
()
self
.
optimizer
.
minimize
(
loss
,
parameter_list
=
self
.
waveflow
.
parameters
())
self
.
optimizer
.
minimize
(
loss
,
parameter_list
=
self
.
waveflow
.
parameters
())
self
.
waveflow
.
clear_gradients
()
graph_time
=
time
.
time
()
...
...
@@ -155,8 +168,8 @@ class WaveFlow():
audio
=
audio
[
0
]
audio_time
=
audio
.
shape
[
0
]
/
self
.
config
.
sample_rate
print
(
"audio time {:.4f}, synthesis time {:.4f}"
.
format
(
audio_time
,
syn_time
))
print
(
"audio time {:.4f}, synthesis time {:.4f}"
.
format
(
audio_time
,
syn_time
))
# Denormalize audio from [-1, 1] to [-32768, 32768] int16 range.
audio
=
audio
.
numpy
()
*
32768.0
...
...
@@ -180,8 +193,8 @@ class WaveFlow():
syn_time
=
time
.
time
()
-
start_time
audio_time
=
audio
.
shape
[
1
]
*
batch_size
/
self
.
config
.
sample_rate
print
(
"audio time {:.4f}, synthesis time {:.4f}"
.
format
(
audio_time
,
syn_time
))
print
(
"audio time {:.4f}, synthesis time {:.4f}"
.
format
(
audio_time
,
syn_time
))
print
(
"{} X real-time"
.
format
(
audio_time
/
syn_time
))
def
save
(
self
,
iteration
):
...
...
parakeet/models/waveflow/waveflow_modules.py
浏览文件 @
f9d97852
...
...
@@ -3,22 +3,23 @@ import itertools
import
numpy
as
np
import
paddle.fluid.dygraph
as
dg
from
paddle
import
fluid
from
parakeet.modules
import
conv
,
modules
,
weight_norm
from
parakeet.modules
import
weight_norm
def
set_param_attr
(
layer
,
c_in
=
1
):
if
isinstance
(
layer
,
(
weight_norm
.
Conv2DTranspose
,
weight_norm
.
Conv2D
))
:
k
=
np
.
sqrt
(
1.0
/
(
c_in
*
np
.
prod
(
layer
.
_
filter_size
)))
def
get_param_attr
(
layer_type
,
filter_size
,
c_in
=
1
):
if
layer_type
==
"weight_norm"
:
k
=
np
.
sqrt
(
1.0
/
(
c_in
*
np
.
prod
(
filter_size
)))
weight_init
=
fluid
.
initializer
.
UniformInitializer
(
low
=-
k
,
high
=
k
)
bias_init
=
fluid
.
initializer
.
UniformInitializer
(
low
=-
k
,
high
=
k
)
elif
isinstance
(
layer
,
dg
.
Conv2D
)
:
elif
layer_type
==
"common"
:
weight_init
=
fluid
.
initializer
.
ConstantInitializer
(
0.0
)
bias_init
=
fluid
.
initializer
.
ConstantInitializer
(
0.0
)
else
:
raise
TypeError
(
"Unsupported layer type."
)
layer
.
_param_attr
=
fluid
.
ParamAttr
(
initializer
=
weight_init
)
layer
.
_bias_attr
=
fluid
.
ParamAttr
(
initializer
=
bias_init
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
weight_init
)
bias_attr
=
fluid
.
ParamAttr
(
initializer
=
bias_init
)
return
param_attr
,
bias_attr
def
unfold
(
x
,
n_group
):
...
...
@@ -48,20 +49,23 @@ class WaveFlowLoss:
class
Conditioner
(
dg
.
Layer
):
def
__init__
(
self
,
name_scope
):
super
(
Conditioner
,
self
).
__init__
(
name_scope
)
def
__init__
(
self
):
super
(
Conditioner
,
self
).
__init__
()
upsample_factors
=
[
16
,
16
]
self
.
upsample_conv2d
=
[]
for
s
in
upsample_factors
:
in_channel
=
1
conv_trans2d
=
modules
.
Conv2DTranspose
(
self
.
full_name
(),
param_attr
,
bias_attr
=
get_param_attr
(
"weight_norm"
,
(
3
,
2
*
s
),
c_in
=
in_channel
)
conv_trans2d
=
weight_norm
.
Conv2DTranspose
(
num_channels
=
in_channel
,
num_filters
=
1
,
filter_size
=
(
3
,
2
*
s
),
padding
=
(
1
,
s
//
2
),
stride
=
(
1
,
s
))
set_param_attr
(
conv_trans2d
,
c_in
=
in_channel
)
stride
=
(
1
,
s
),
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
self
.
upsample_conv2d
.
append
(
conv_trans2d
)
for
i
,
layer
in
enumerate
(
self
.
upsample_conv2d
):
...
...
@@ -86,8 +90,8 @@ class Conditioner(dg.Layer):
class
Flow
(
dg
.
Layer
):
def
__init__
(
self
,
name_scope
,
config
):
super
(
Flow
,
self
).
__init__
(
name_scope
)
def
__init__
(
self
,
config
):
super
(
Flow
,
self
).
__init__
()
self
.
n_layers
=
config
.
n_layers
self
.
n_channels
=
config
.
n_channels
self
.
kernel_h
=
config
.
kernel_h
...
...
@@ -95,27 +99,34 @@ class Flow(dg.Layer):
# Transform audio: [batch, 1, n_group, time/n_group]
# => [batch, n_channels, n_group, time/n_group]
param_attr
,
bias_attr
=
get_param_attr
(
"weight_norm"
,
(
1
,
1
),
c_in
=
1
)
self
.
start
=
weight_norm
.
Conv2D
(
self
.
full_name
()
,
num_channels
=
1
,
num_filters
=
self
.
n_channels
,
filter_size
=
(
1
,
1
))
set_param_attr
(
self
.
start
,
c_in
=
1
)
filter_size
=
(
1
,
1
),
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
# output shape: [batch, 2, n_group, time/n_group]
param_attr
,
bias_attr
=
get_param_attr
(
"common"
,
(
1
,
1
),
c_in
=
self
.
n_channels
)
self
.
end
=
dg
.
Conv2D
(
self
.
full_name
()
,
num_channels
=
self
.
n_channels
,
num_filters
=
2
,
filter_size
=
(
1
,
1
))
set_param_attr
(
self
.
end
)
filter_size
=
(
1
,
1
),
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
# receiptive fileds: (kernel - 1) * sum(dilations) + 1 >= squeeze
dilation_dict
=
{
8
:
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
],
dilation_dict
=
{
8
:
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
],
16
:
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
],
32
:
[
1
,
2
,
4
,
1
,
2
,
4
,
1
,
2
],
64
:
[
1
,
2
,
4
,
8
,
16
,
1
,
2
,
4
],
128
:
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
1
]}
128
:
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
1
]
}
self
.
dilation_h_list
=
dilation_dict
[
config
.
n_group
]
self
.
in_layers
=
[]
...
...
@@ -123,32 +134,42 @@ class Flow(dg.Layer):
self
.
res_skip_layers
=
[]
for
i
in
range
(
self
.
n_layers
):
dilation_h
=
self
.
dilation_h_list
[
i
]
dilation_w
=
2
**
i
dilation_w
=
2
**
i
param_attr
,
bias_attr
=
get_param_attr
(
"weight_norm"
,
(
self
.
kernel_h
,
self
.
kernel_w
),
c_in
=
self
.
n_channels
)
in_layer
=
weight_norm
.
Conv2D
(
self
.
full_name
()
,
num_channels
=
self
.
n_channels
,
num_filters
=
2
*
self
.
n_channels
,
filter_size
=
(
self
.
kernel_h
,
self
.
kernel_w
),
dilation
=
(
dilation_h
,
dilation_w
))
set_param_attr
(
in_layer
,
c_in
=
self
.
n_channels
)
dilation
=
(
dilation_h
,
dilation_w
),
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
self
.
in_layers
.
append
(
in_layer
)
param_attr
,
bias_attr
=
get_param_attr
(
"weight_norm"
,
(
1
,
1
),
c_in
=
config
.
mel_bands
)
cond_layer
=
weight_norm
.
Conv2D
(
self
.
full_name
()
,
num_channels
=
config
.
mel_bands
,
num_filters
=
2
*
self
.
n_channels
,
filter_size
=
(
1
,
1
))
set_param_attr
(
cond_layer
,
c_in
=
config
.
mel_bands
)
filter_size
=
(
1
,
1
),
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
self
.
cond_layers
.
append
(
cond_layer
)
if
i
<
self
.
n_layers
-
1
:
res_skip_channels
=
2
*
self
.
n_channels
else
:
res_skip_channels
=
self
.
n_channels
param_attr
,
bias_attr
=
get_param_attr
(
"weight_norm"
,
(
1
,
1
),
c_in
=
self
.
n_channels
)
res_skip_layer
=
weight_norm
.
Conv2D
(
self
.
full_name
()
,
num_channels
=
self
.
n_channels
,
num_filters
=
res_skip_channels
,
filter_size
=
(
1
,
1
))
set_param_attr
(
res_skip_layer
,
c_in
=
self
.
n_channels
)
filter_size
=
(
1
,
1
),
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
self
.
res_skip_layers
.
append
(
res_skip_layer
)
self
.
add_sublayer
(
"in_layer_{}"
.
format
(
i
),
in_layer
)
...
...
@@ -162,14 +183,14 @@ class Flow(dg.Layer):
for
i
in
range
(
self
.
n_layers
):
dilation_h
=
self
.
dilation_h_list
[
i
]
dilation_w
=
2
**
i
dilation_w
=
2
**
i
# Pad height dim (n_group): causal convolution
# Pad width dim (time): dialated non-causal convolution
pad_top
,
pad_bottom
=
(
self
.
kernel_h
-
1
)
*
dilation_h
,
0
pad_left
=
pad_right
=
int
((
self
.
kernel_w
-
1
)
*
dilation_w
/
2
)
audio_pad
=
fluid
.
layers
.
pad2d
(
audio
,
paddings
=
[
pad_top
,
pad_bottom
,
pad_left
,
pad_right
])
pad_left
=
pad_right
=
int
((
self
.
kernel_w
-
1
)
*
dilation_w
/
2
)
audio_pad
=
fluid
.
layers
.
pad2d
(
audio
,
paddings
=
[
pad_top
,
pad_bottom
,
pad_left
,
pad_right
])
hidden
=
self
.
in_layers
[
i
](
audio_pad
)
cond_hidden
=
self
.
cond_layers
[
i
](
mel
)
...
...
@@ -196,7 +217,7 @@ class Flow(dg.Layer):
for
i
in
range
(
self
.
n_layers
):
dilation_h
=
self
.
dilation_h_list
[
i
]
dilation_w
=
2
**
i
dilation_w
=
2
**
i
state_size
=
dilation_h
*
(
self
.
kernel_h
-
1
)
queue
=
queues
[
i
]
...
...
@@ -206,7 +227,7 @@ class Flow(dg.Layer):
queue
.
append
(
fluid
.
layers
.
zeros_like
(
audio
))
state
=
queue
[
0
:
state_size
]
state
=
fluid
.
layers
.
concat
(
[
*
state
,
audio
],
axis
=
2
)
state
=
fluid
.
layers
.
concat
(
state
+
[
audio
],
axis
=
2
)
queue
.
pop
(
0
)
queue
.
append
(
audio
)
...
...
@@ -214,10 +235,10 @@ class Flow(dg.Layer):
# Pad height dim (n_group): causal convolution
# Pad width dim (time): dialated non-causal convolution
pad_top
,
pad_bottom
=
0
,
0
pad_left
=
int
((
self
.
kernel_w
-
1
)
*
dilation_w
/
2
)
pad_right
=
int
((
self
.
kernel_w
-
1
)
*
dilation_w
/
2
)
state
=
fluid
.
layers
.
pad2d
(
state
,
paddings
=
[
pad_top
,
pad_bottom
,
pad_left
,
pad_right
])
pad_left
=
int
((
self
.
kernel_w
-
1
)
*
dilation_w
/
2
)
pad_right
=
int
((
self
.
kernel_w
-
1
)
*
dilation_w
/
2
)
state
=
fluid
.
layers
.
pad2d
(
state
,
paddings
=
[
pad_top
,
pad_bottom
,
pad_left
,
pad_right
])
hidden
=
self
.
in_layers
[
i
](
state
)
cond_hidden
=
self
.
cond_layers
[
i
](
mel
)
...
...
@@ -241,18 +262,18 @@ class Flow(dg.Layer):
class
WaveFlowModule
(
dg
.
Layer
):
def
__init__
(
self
,
name_scope
,
config
):
super
(
WaveFlowModule
,
self
).
__init__
(
name_scope
)
def
__init__
(
self
,
config
):
super
(
WaveFlowModule
,
self
).
__init__
()
self
.
n_flows
=
config
.
n_flows
self
.
n_group
=
config
.
n_group
self
.
n_layers
=
config
.
n_layers
assert
self
.
n_group
%
2
==
0
assert
self
.
n_flows
%
2
==
0
self
.
conditioner
=
Conditioner
(
self
.
full_name
()
)
self
.
conditioner
=
Conditioner
()
self
.
flows
=
[]
for
i
in
range
(
self
.
n_flows
):
flow
=
Flow
(
self
.
full_name
(),
config
)
flow
=
Flow
(
config
)
self
.
flows
.
append
(
flow
)
self
.
add_sublayer
(
"flow_{}"
.
format
(
i
),
flow
)
...
...
@@ -284,7 +305,6 @@ class WaveFlowModule(dg.Layer):
audio
=
fluid
.
layers
.
transpose
(
unfold
(
audio
,
self
.
n_group
),
[
0
,
2
,
1
])
# [bs, 1, n_group, time/n_group]
audio
=
fluid
.
layers
.
unsqueeze
(
audio
,
1
)
log_s_list
=
[]
for
i
in
range
(
self
.
n_flows
):
inputs
=
audio
[:,
:,
:
-
1
,
:]
...
...
@@ -305,7 +325,6 @@ class WaveFlowModule(dg.Layer):
mel
=
fluid
.
layers
.
stack
(
mel_slices
,
axis
=
2
)
z
=
fluid
.
layers
.
squeeze
(
audio
,
[
1
])
return
z
,
log_s_list
def
synthesize
(
self
,
mel
,
sigma
=
1.0
):
...
...
@@ -331,7 +350,7 @@ class WaveFlowModule(dg.Layer):
for
h
in
range
(
1
,
self
.
n_group
):
inputs
=
audio_h
conds
=
mel
[:,
:,
h
:(
h
+
1
),
:]
conds
=
mel
[:,
:,
h
:(
h
+
1
),
:]
outputs
=
self
.
flows
[
i
].
infer
(
inputs
,
conds
,
queues
)
log_s
=
outputs
[:,
0
:
1
,
:,
:]
...
...
parakeet/modules/weight_norm.py
浏览文件 @
f9d97852
...
...
@@ -40,8 +40,8 @@ def norm_except(param, dim, power):
def
compute_weight
(
v
,
g
,
dim
,
power
):
assert
len
(
g
.
shape
)
==
1
,
"magnitude should be a vector"
v_normalized
=
F
.
elementwise_div
(
v
,
(
norm_except
(
v
,
dim
,
power
)
+
1e-12
),
axis
=
dim
)
v_normalized
=
F
.
elementwise_div
(
v
,
(
norm_except
(
v
,
dim
,
power
)
+
1e-12
),
axis
=
dim
)
weight
=
F
.
elementwise_mul
(
v_normalized
,
g
,
axis
=
dim
)
return
weight
...
...
@@ -63,20 +63,21 @@ class WeightNormWrapper(dg.Layer):
original_weight
=
getattr
(
layer
,
param_name
)
self
.
add_parameter
(
w_v
,
self
.
create_parameter
(
shape
=
original_weight
.
shape
,
dtype
=
original_weight
.
dtype
))
self
.
create_parameter
(
shape
=
original_weight
.
shape
,
dtype
=
original_weight
.
dtype
))
F
.
assign
(
original_weight
,
getattr
(
self
,
w_v
))
delattr
(
layer
,
param_name
)
temp
=
norm_except
(
getattr
(
self
,
w_v
),
self
.
dim
,
self
.
power
)
self
.
add_parameter
(
w_g
,
self
.
create_parameter
(
shape
=
temp
.
shape
,
dtype
=
temp
.
dtype
))
w_g
,
self
.
create_parameter
(
shape
=
temp
.
shape
,
dtype
=
temp
.
dtype
))
F
.
assign
(
temp
,
getattr
(
self
,
w_g
))
# also set this when setting up
setattr
(
self
.
layer
,
self
.
param_name
,
compute_weight
(
getattr
(
self
,
w_v
),
getattr
(
self
,
w_g
),
self
.
dim
,
self
.
power
))
setattr
(
self
.
layer
,
self
.
param_name
,
compute_weight
(
getattr
(
self
,
w_v
)
,
getattr
(
self
,
w_g
),
self
.
dim
,
self
.
power
))
self
.
weigth_norm_applied
=
True
...
...
@@ -84,10 +85,10 @@ class WeightNormWrapper(dg.Layer):
def
hook
(
self
):
w_v
=
self
.
param_name
+
"_v"
w_g
=
self
.
param_name
+
"_g"
setattr
(
self
.
layer
,
self
.
param_name
,
compute_weight
(
getattr
(
self
,
w_v
),
getattr
(
self
,
w_g
),
self
.
dim
,
self
.
power
))
setattr
(
self
.
layer
,
self
.
param_name
,
compute_weight
(
getattr
(
self
,
w_v
)
,
getattr
(
self
,
w_g
),
self
.
dim
,
self
.
power
))
def
remove_weight_norm
(
self
):
self
.
hook
()
...
...
@@ -112,6 +113,13 @@ class WeightNormWrapper(dg.Layer):
return
getattr
(
object
.
__getattribute__
(
self
,
"_sub_layers"
)[
"layer"
],
key
)
def
__setattr__
(
self
,
name
,
value
):
if
name
==
"_param_attr"
or
name
==
"_bias_attr"
:
print
(
name
)
setattr
(
self
.
layer
,
name
,
value
)
else
:
super
().
__setattr__
(
name
,
value
)
def
Linear
(
input_dim
,
output_dim
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录