From 9b86f2008d4e4ad846bd33fd0dc33490327fc42f Mon Sep 17 00:00:00 2001 From: lifuchen Date: Wed, 19 Feb 2020 12:55:15 +0000 Subject: [PATCH] fix a bug of transformertts when use data parallel. --- examples/transformer_tts/train_transformer.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/transformer_tts/train_transformer.py b/examples/transformer_tts/train_transformer.py index d258209..02284f7 100644 --- a/examples/transformer_tts/train_transformer.py +++ b/examples/transformer_tts/train_transformer.py @@ -94,10 +94,16 @@ def main(args): if args.stop_token: writer.add_scalar('stop_loss', stop_loss.numpy(), global_step) - writer.add_scalars('alphas', { - 'encoder_alpha':model.encoder.alpha.numpy(), - 'decoder_alpha':model.decoder.alpha.numpy(), - }, global_step) + if args.use_data_parallel: + writer.add_scalars('alphas', { + 'encoder_alpha':model._layers.encoder.alpha.numpy(), + 'decoder_alpha':model._layers.decoder.alpha.numpy(), + }, global_step) + else: + writer.add_scalars('alphas', { + 'encoder_alpha':model.encoder.alpha.numpy(), + 'decoder_alpha':model.decoder.alpha.numpy(), + }, global_step) writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step) @@ -144,4 +150,4 @@ if __name__ =='__main__': args = parser.parse_args() # Print the whole config setting. pprint(args) - main(args) \ No newline at end of file + main(args) -- GitLab