提交 9b86f200 编写于 作者: L lifuchen

fix a bug of transformertts when use data parallel.

上级 6428ce54
...@@ -94,10 +94,16 @@ def main(args): ...@@ -94,10 +94,16 @@ def main(args):
if args.stop_token: if args.stop_token:
writer.add_scalar('stop_loss', stop_loss.numpy(), global_step) writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
writer.add_scalars('alphas', { if args.use_data_parallel:
'encoder_alpha':model.encoder.alpha.numpy(), writer.add_scalars('alphas', {
'decoder_alpha':model.decoder.alpha.numpy(), 'encoder_alpha':model._layers.encoder.alpha.numpy(),
}, global_step) 'decoder_alpha':model._layers.decoder.alpha.numpy(),
}, global_step)
else:
writer.add_scalars('alphas', {
'encoder_alpha':model.encoder.alpha.numpy(),
'decoder_alpha':model.decoder.alpha.numpy(),
}, global_step)
writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step) writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step)
...@@ -144,4 +150,4 @@ if __name__ =='__main__': ...@@ -144,4 +150,4 @@ if __name__ =='__main__':
args = parser.parse_args() args = parser.parse_args()
# Print the whole config setting. # Print the whole config setting.
pprint(args) pprint(args)
main(args) main(args)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册