Commit 993a28bc authored by mindspore-ci-bot, committed by Gitee

!4275 add allreduce grouping for resnet gpu version

Merge pull request !4275 from yuchaojie/add_allreduce_group_for_resnet_gpu
@@ -275,7 +275,7 @@ class _AutoParallelContext:
         Args:
             indices (list): Indices list.
-            group (str): The hccl communication group.
+            group (str): The communication group of hccl/nccl.
         Raises:
             TypeError: If type of indices item is not int.
@@ -311,7 +311,7 @@
         Get allreduce fusion split indices.
         Args:
-            group (str): The hccl communication group.
+            group (str): The communication group of hccl/nccl.
         Returns:
             Return split sizes list according to the group.
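For readers skimming the diff: these are the setter/getter that the train.py change below drives. A minimal sketch of the call pattern, assuming the private _AutoParallelContext is reached through the auto_parallel_context() helper (as in this PR's train.py) and that the group argument has a default, since the call in the diff passes only the indices list:

    # Hedged sketch, not part of this PR: fuse gradient AllReduce into two
    # buffers, tensors [0, 85) and [85, 160), instead of one AllReduce per
    # gradient tensor. Default hccl/nccl communication group assumed.
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    ctx = auto_parallel_context()
    ctx.set_all_reduce_fusion_split_indices([85, 160])
    print(ctx.get_all_reduce_fusion_split_indices())  # expected: [85, 160]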
@@ -340,7 +340,7 @@
         Args:
             sizes (list): Sizes list.
-            group (str): The hccl communication group.
+            group (str): The communication group of hccl/nccl.
         Raises:
             TypeError: If type of sizes item is not int.
@@ -376,7 +376,7 @@
         Get allreduce fusion split sizes.
         Args:
-            group (str): The hccl communication group.
+            group (str): The communication group of hccl/nccl.
         Returns:
             Return split sizes list according to the group.
......
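The size-based pair mirrors the index-based one. The docstrings above say only "Sizes list", so the unit and semantics of each entry are not specified by this diff; the values below are placeholders, and the no-argument getter assumes the same group default as above:

    # Hedged sketch: size-based fusion split. The meaning/unit of each entry
    # is not documented in this diff; [64, 96] is purely illustrative.
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    ctx = auto_parallel_context()
    ctx.set_all_reduce_fusion_split_sizes([64, 96])
    print(ctx.get_all_reduce_fusion_split_sizes())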
@@ -44,7 +44,7 @@ ImageNet2012
         ├── run_distribute_train.sh            # launch distributed training(8 pcs)
         ├── run_parameter_server_train.sh      # launch Ascend parameter server training(8 pcs)
         ├── run_eval.sh                        # launch evaluation
         ├── run_standalone_train.sh            # launch standalone training(1 pcs)
         ├── run_distribute_train_gpu.sh        # launch gpu distributed training(8 pcs)
         ├── run_parameter_server_train_gpu.sh  # launch gpu parameter server training(8 pcs)
         ├── run_eval_gpu.sh                    # launch gpu evaluation
......
@@ -81,9 +81,11 @@ if __name__ == '__main__':
         init()
     # GPU target
     else:
-        init("nccl")
         context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                           mirror_mean=True)
+        if args_opt.net == "resnet50":
+            auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
+        init("nccl")
     ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
     # create dataset
......
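The functional change is twofold: ResNet-50's gradient AllReduce is split into two fused groups at indices 85 and 160, and init("nccl") is moved after that configuration so the grouping is already registered when NCCL communication is initialized. Below is a self-contained sketch of the resulting GPU branch; the module paths follow MindSpore examples of this era and may differ by version, and args_opt, config, and the checkpoint path are hypothetical stand-ins for train.py's real argument parsing and config:

    # Hedged reconstruction of train.py's GPU branch after this PR.
    import argparse

    from mindspore import context
    from mindspore.communication.management import init, get_group_size, get_rank
    from mindspore.context import ParallelMode
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    parser = argparse.ArgumentParser()
    parser.add_argument('--net', default='resnet50')  # stand-in for train.py's args
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_auto_parallel_context(device_num=get_group_size(),
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True)
    if args_opt.net == "resnet50":
        # Two fused AllReduce buffers instead of per-tensor reductions.
        auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
    # Per the diff, NCCL init now follows the fusion configuration, so the
    # grouping is in place when communication groups are created.
    init("nccl")
    ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/"  # per-rank checkpoint dir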