提交 8b051486 编写于 作者: L liuyibing01

Update synthesis script for waveflow

上级 7635493a
...@@ -2,13 +2,13 @@ import os ...@@ -2,13 +2,13 @@ import os
import random import random
from pprint import pprint from pprint import pprint
import jsonargparse import argparse
import numpy as np import numpy as np
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
from paddle import fluid from paddle import fluid
import utils import utils
from waveflow import WaveFlow from parakeet.models.waveflow import WaveFlow
def add_options_to_parser(parser): def add_options_to_parser(parser):
...@@ -53,7 +53,7 @@ def add_options_to_parser(parser): ...@@ -53,7 +53,7 @@ def add_options_to_parser(parser):
def synthesize(config): def synthesize(config):
pprint(jsonargparse.namespace_to_dict(config)) pprint(vars(config))
# Get checkpoint directory path. # Get checkpoint directory path.
run_dir = os.path.join("runs", config.model, config.name) run_dir = os.path.join("runs", config.model, config.name)
...@@ -90,9 +90,8 @@ def synthesize(config): ...@@ -90,9 +90,8 @@ def synthesize(config):
if __name__ == "__main__": if __name__ == "__main__":
# Create parser. # Create parser.
parser = jsonargparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Synthesize audio using WaveNet model", description="Synthesize audio using WaveNet model")
formatter_class='default_argparse')
add_options_to_parser(parser) add_options_to_parser(parser)
utils.add_config_options_to_parser(parser) utils.add_config_options_to_parser(parser)
...@@ -100,4 +99,5 @@ if __name__ == "__main__": ...@@ -100,4 +99,5 @@ if __name__ == "__main__":
# For conflicting updates to the same field, # For conflicting updates to the same field,
# the preceding update will be overwritten by the following one. # the preceding update will be overwritten by the following one.
config = parser.parse_args() config = parser.parse_args()
config = utils.add_yaml_config(config)
synthesize(config) synthesize(config)
...@@ -84,7 +84,6 @@ def add_config_options_to_parser(parser): ...@@ -84,7 +84,6 @@ def add_config_options_to_parser(parser):
def add_yaml_config(config): def add_yaml_config(config):
print(config)
with open(config.config, 'rt') as f: with open(config.config, 'rt') as f:
yaml_cfg = ruamel.yaml.safe_load(f) yaml_cfg = ruamel.yaml.safe_load(f)
cfg_vars = vars(config) cfg_vars = vars(config)
......
...@@ -152,7 +152,8 @@ class WaveFlow(): ...@@ -152,7 +152,8 @@ class WaveFlow():
sample = config.sample sample = config.sample
output = "{}/{}/iter-{}".format(config.output, config.name, iteration) output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
os.makedirs(output, exist_ok=True) if not os.path.exists(output):
os.makedirs(output)
mels_list = [mels for _, mels in self.validloader()] mels_list = [mels for _, mels in self.validloader()]
if sample is not None: if sample is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册