Commit eca6c9c6 authored by moran

Fix wizard template module to fit new operator API

Parent b7e7681e
@@ -60,14 +60,14 @@ if __name__ == "__main__":
         device_id = int(os.getenv('DEVICE_ID'))
         context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
         context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          mirror_mean=True)
+                                          gradients_mean=True)
         init()
     # GPU target
     else:
         init("nccl")
         context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          mirror_mean=True)
+                                          gradients_mean=True)
         ckpt_save_dir = cfg.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
...
@@ -60,7 +60,7 @@ if __name__ == "__main__":
         raise ValueError('Distributed running is not supported on %s' % args.device_target)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                      mirror_mean=True)
+                                      gradients_mean=True)
     data_path = args.dataset_path
     do_train = True
...
@@ -65,7 +65,7 @@ if __name__ == '__main__':
         device_id = int(os.getenv('DEVICE_ID'))
         context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
         context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          mirror_mean=True)
+                                          gradients_mean=True)
         auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
         init()
@@ -73,7 +73,7 @@ if __name__ == '__main__':
     else:
         init("nccl")
         context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          mirror_mean=True)
+                                          gradients_mean=True)
         ckpt_save_dir = cfg.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
         # create dataset
...
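
All four hunks make the same mechanical change: under the new operator API, the mirror_mean keyword of context.set_auto_parallel_context is renamed to gradients_mean (the flag that averages gradients across devices after allreduce). Below is a minimal sketch of the rename, assuming MindSpore's standard imports as used in these template scripts; the device_num=8 value is illustrative, not taken from the commit.

    # Minimal sketch of the keyword rename; device_num=8 is an assumed example value.
    from mindspore import context
    from mindspore.context import ParallelMode

    context.reset_auto_parallel_context()

    # Old keyword, no longer accepted by the new operator API:
    # context.set_auto_parallel_context(device_num=8,
    #                                   parallel_mode=ParallelMode.DATA_PARALLEL,
    #                                   mirror_mean=True)

    # New keyword: gradients_mean enables averaging gradients across
    # devices after the allreduce step in data-parallel training.
    context.set_auto_parallel_context(device_num=8,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True)

The behavior is unchanged; only the argument name differs, which is why each hunk is a one-line substitution.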