提交 ed9cf203 作者: yuchaojie

Add a default NCCL allreduce fusion group (used when no group is given and the device target is not Ascend; Ascend keeps the HCCL default)

上级 16079e63
...@@ -20,6 +20,8 @@ from mindspore._c_expression import AutoParallelContext ...@@ -20,6 +20,8 @@ from mindspore._c_expression import AutoParallelContext
from mindspore._checkparam import args_type_check from mindspore._checkparam import args_type_check
_MAX_GROUP_NAME_LEN = 127 _MAX_GROUP_NAME_LEN = 127
_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"
class _AutoParallelContext: class _AutoParallelContext:
...@@ -267,7 +269,7 @@ class _AutoParallelContext: ...@@ -267,7 +269,7 @@ class _AutoParallelContext:
self.check_context_handle() self.check_context_handle()
return self._context_handle.get_parameter_broadcast_is_set() return self._context_handle.get_parameter_broadcast_is_set()
def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"): def set_all_reduce_fusion_split_indices(self, indices, group=""):
""" """
Set allreduce fusion strategy by parameters indices. Set allreduce fusion strategy by parameters indices.
...@@ -294,11 +296,17 @@ class _AutoParallelContext: ...@@ -294,11 +296,17 @@ class _AutoParallelContext:
else: else:
raise TypeError('Group must be a python str') raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
self._context_handle.set_all_reduce_fusion_split_indices(indices, group) self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
if context.get_context("device_target") == "Ascend": if context.get_context("device_target") == "Ascend":
_set_fusion_strategy_by_idx(indices) _set_fusion_strategy_by_idx(indices)
def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): def get_all_reduce_fusion_split_indices(self, group=""):
""" """
Get allreduce fusion split indices. Get allreduce fusion split indices.
...@@ -318,9 +326,15 @@ class _AutoParallelContext: ...@@ -318,9 +326,15 @@ class _AutoParallelContext:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else: else:
raise TypeError('Group must be a python str') raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
return self._context_handle.get_all_reduce_fusion_split_indices(group) return self._context_handle.get_all_reduce_fusion_split_indices(group)
def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"): def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
""" """
Set allreduce fusion strategy by parameters data sizes. Set allreduce fusion strategy by parameters data sizes.
...@@ -347,11 +361,17 @@ class _AutoParallelContext: ...@@ -347,11 +361,17 @@ class _AutoParallelContext:
else: else:
raise TypeError('Group must be a python str') raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
if context.get_context("device_target") == "Ascend": if context.get_context("device_target") == "Ascend":
_set_fusion_strategy_by_size(sizes) _set_fusion_strategy_by_size(sizes)
def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): def get_all_reduce_fusion_split_sizes(self, group=""):
""" """
Get allreduce fusion split sizes. Get allreduce fusion split sizes.
...@@ -371,6 +391,12 @@ class _AutoParallelContext: ...@@ -371,6 +391,12 @@ class _AutoParallelContext:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else: else:
raise TypeError('Group must be a python str') raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
return self._context_handle.get_all_reduce_fusion_split_sizes(group) return self._context_handle.get_all_reduce_fusion_split_sizes(group)
def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册