提交 4dd9a273 编写于 作者: L liuyibing01

Merge branch 'fix-name' into 'master'

modified the name of vocoder

See merge request !17
...@@ -42,13 +42,13 @@ def synthesis(text_input, args): ...@@ -42,13 +42,13 @@ def synthesis(text_input, args):
with dg.guard(place): with dg.guard(place):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
model = TransformerTTS(cfg) model = TransformerTTS(cfg)
model.set_dict(load_checkpoint(str(args.transformer_step), os.path.join(args.checkpoint_path, "nostop_token/transformer"))) model.set_dict(load_checkpoint(str(args.transformer_step), os.path.join(args.checkpoint_path, "transformer")))
model.eval() model.eval()
with fluid.unique_name.guard(): with fluid.unique_name.guard():
model_postnet = Vocoder(cfg, args.batch_size) model_vocoder = Vocoder(cfg, args.batch_size)
model_postnet.set_dict(load_checkpoint(str(args.postnet_step), os.path.join(args.checkpoint_path, "postnet"))) model_vocoder.set_dict(load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "vocoder")))
model_postnet.eval() model_vocoder.eval()
# init input # init input
text = np.asarray(text_to_sequence(text_input)) text = np.asarray(text_to_sequence(text_input))
text = fluid.layers.unsqueeze(dg.to_variable(text),[0]) text = fluid.layers.unsqueeze(dg.to_variable(text),[0])
...@@ -64,7 +64,7 @@ def synthesis(text_input, args): ...@@ -64,7 +64,7 @@ def synthesis(text_input, args):
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel),[0]) pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel),[0])
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(text, mel_input, pos_text, pos_mel) mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(text, mel_input, pos_text, pos_mel)
mel_input = fluid.layers.concat([mel_input, postnet_pred[:,-1:,:]], axis=1) mel_input = fluid.layers.concat([mel_input, postnet_pred[:,-1:,:]], axis=1)
mag_pred = model_postnet(postnet_pred) mag_pred = model_vocoder(postnet_pred)
_ljspeech_processor = audio.AudioProcessor( _ljspeech_processor = audio.AudioProcessor(
sample_rate=cfg['audio']['sr'], sample_rate=cfg['audio']['sr'],
......
...@@ -38,7 +38,7 @@ def main(args): ...@@ -38,7 +38,7 @@ def main(args):
if not os.path.exists(args.log_dir): if not os.path.exists(args.log_dir):
os.mkdir(args.log_dir) os.mkdir(args.log_dir)
path = os.path.join(args.log_dir,'postnet') path = os.path.join(args.log_dir,'vocoder')
writer = SummaryWriter(path) if local_rank == 0 else None writer = SummaryWriter(path) if local_rank == 0 else None
...@@ -51,7 +51,7 @@ def main(args): ...@@ -51,7 +51,7 @@ def main(args):
if args.checkpoint_path is not None: if args.checkpoint_path is not None:
model_dict, opti_dict = load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "postnet")) model_dict, opti_dict = load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "vocoder"))
model.set_dict(model_dict) model.set_dict(model_dict)
optimizer.set_dict(opti_dict) optimizer.set_dict(opti_dict)
global_step = args.vocoder_step global_step = args.vocoder_step
...@@ -92,7 +92,7 @@ def main(args): ...@@ -92,7 +92,7 @@ def main(args):
if global_step % args.save_step == 0: if global_step % args.save_step == 0:
if not os.path.exists(args.save_path): if not os.path.exists(args.save_path):
os.mkdir(args.save_path) os.mkdir(args.save_path)
save_path = os.path.join(args.save_path,'postnet/%d' % global_step) save_path = os.path.join(args.save_path,'vocoder/%d' % global_step)
dg.save_dygraph(model.state_dict(), save_path) dg.save_dygraph(model.state_dict(), save_path)
dg.save_dygraph(optimizer.state_dict(), save_path) dg.save_dygraph(optimizer.state_dict(), save_path)
...@@ -100,7 +100,7 @@ def main(args): ...@@ -100,7 +100,7 @@ def main(args):
writer.close() writer.close()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Train postnet model") parser = argparse.ArgumentParser(description="Train vocoder model")
add_config_options_to_parser(parser) add_config_options_to_parser(parser)
args = parser.parse_args() args = parser.parse_args()
# Print the whole config setting. # Print the whole config setting.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册