Commit 3ce29513, author: yuchaojie

only save ckpt in rank0 for Transformer

Parent 3d377c51
@@ -147,6 +147,7 @@ def run_transformer_train():
     callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
     if args.enable_save_ckpt == "true":
+        if device_num == 1 or (device_num > 1 and rank_id == 0):
             ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
                                            keep_checkpoint_max=args.save_checkpoint_num)
             ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
......
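
The guard added in this commit can be sketched in isolation as follows. This is a minimal illustration, not the code from the repository: `should_save_checkpoint` and `build_callbacks` are hypothetical helper names, and plain strings stand in for the real `TimeMonitor`, `LossCallBack`, and `ModelCheckpoint` objects so the snippet runs without MindSpore installed. It assumes, as the commit title suggests, that one checkpoint copy written by rank 0 is enough for a multi-device run.

```python
def should_save_checkpoint(device_num: int, rank_id: int) -> bool:
    """Mirror the condition added in the diff (hypothetical helper name)."""
    # Single-device runs always save; multi-device runs save on rank 0 only.
    return device_num == 1 or (device_num > 1 and rank_id == 0)


def build_callbacks(device_num: int, rank_id: int, enable_save_ckpt: str = "true") -> list:
    """Assemble the callback list the way the patched training function does.

    Strings are placeholders for the real callback objects (assumption for
    illustration only).
    """
    callbacks = ["TimeMonitor", "LossCallBack"]
    if enable_save_ckpt == "true":
        if should_save_checkpoint(device_num, rank_id):
            callbacks.append("ModelCheckpoint")
    return callbacks


if __name__ == "__main__":
    # Show which ranks of a hypothetical 4-device job would attach the
    # checkpoint callback.
    for rank in range(4):
        print(rank, build_callbacks(device_num=4, rank_id=rank))
```

Keeping the rank check next to the `enable_save_ckpt` check means the other ranks still get the timing and loss callbacks and only skip the checkpoint writer, which avoids several processes writing the same `transformer` checkpoint files into one directory.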