From 3ce29513db1098b605a0c9bfb29349613cf51a45 Mon Sep 17 00:00:00 2001
From: yuchaojie
Date: Mon, 6 Jul 2020 19:43:49 +0800
Subject: [PATCH] only save ckpt in rank0 for Transformer

---
 model_zoo/Transformer/train.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/model_zoo/Transformer/train.py b/model_zoo/Transformer/train.py
index 23c0eb78f..ffd6b8c71 100644
--- a/model_zoo/Transformer/train.py
+++ b/model_zoo/Transformer/train.py
@@ -147,10 +147,11 @@ def run_transformer_train():
     callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
 
     if args.enable_save_ckpt == "true":
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
-                                       keep_checkpoint_max=args.save_checkpoint_num)
-        ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
-        callbacks.append(ckpoint_cb)
+        if device_num == 1 or (device_num > 1 and rank_id == 0):
+            ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
+                                           keep_checkpoint_max=args.save_checkpoint_num)
+            ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
+            callbacks.append(ckpoint_cb)
 
     if args.enable_lossscale == "true":
         scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value,
--
GitLab
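
Note (not part of the patch): the new guard assumes that device_num and rank_id are already defined earlier in run_transformer_train(). A minimal sketch of how these values are typically obtained in a MindSpore distributed run, assuming the standard mindspore.communication.management API and a hypothetical args.distribute flag (the actual train.py may derive them differently, e.g. from environment variables):

    # Sketch only: uses MindSpore's collective communication helpers; the flag
    # name args.distribute is an assumption, not taken from this train.py.
    from mindspore.communication.management import init, get_rank, get_group_size

    if args.distribute == "true":
        init()                         # initialize the collective communication backend
        device_num = get_group_size()  # total number of devices in the job
        rank_id = get_rank()           # this process's global rank, 0-based
    else:
        device_num = 1                 # single-device training
        rank_id = 0

With values like these, only rank 0 appends the ModelCheckpoint callback, so a multi-device run writes one set of checkpoint files instead of one copy per device.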