提交 cf9f8818 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2201 Add save&load ckpt path for distribution training

Merge pull request !2201 from chenweifeng/gpu_bert
......@@ -68,7 +68,8 @@ def run_pretrain():
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
"default is 1000.")
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
......@@ -81,7 +82,7 @@ def run_pretrain():
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
context.set_context(reserve_class_name_in_scope=False)
ckpt_save_dir = args_opt.checkpoint_path
ckpt_save_dir = args_opt.save_checkpoint_path
if args_opt.distribute == "true":
if args_opt.device_target == 'Ascend':
D.init('hccl')
......@@ -91,7 +92,7 @@ def run_pretrain():
D.init('nccl')
device_num = D.get_group_size()
rank = D.get_rank()
ckpt_save_dir = args_opt.checkpoint_path + 'ckpt_' + str(rank) + '/'
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
......@@ -150,8 +151,8 @@ def run_pretrain():
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
callback.append(ckpoint_cb)
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
if args_opt.load_checkpoint_path:
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
load_param_into_net(netwithloss, param_dict)
if args_opt.enable_lossscale == "true":
......
......@@ -64,7 +64,7 @@ do
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=100 \
--checkpoint_path="" \
--load_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
......
......@@ -36,7 +36,7 @@ mpirun --allow-run-as-root -n $RANK_SIZE \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--load_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
......
......@@ -38,7 +38,7 @@ python run_pretrain.py \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--load_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
......
......@@ -40,7 +40,8 @@ python run_pretrain.py \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--load_checkpoint_path="" \
--save_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册