提交 789edcb2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!654 change get_group to internal interface

Merge pull request !654 from lichen/change_get_group_to_internal_interface
...@@ -17,12 +17,12 @@ Collective communication interface. ...@@ -17,12 +17,12 @@ Collective communication interface.
""" """
from .management import GlobalComm, init, release, get_rank, get_group_size, get_world_rank_from_group_rank, \ from .management import GlobalComm, init, release, get_rank, get_group_size, get_world_rank_from_group_rank, \
get_group_rank_from_world_rank, create_group, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, get_group, \ get_group_rank_from_world_rank, create_group, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
get_local_rank, get_local_rank_size, destroy_group get_local_rank, get_local_rank_size, destroy_group
__all__ = [ __all__ = [
"GlobalComm", "init", "release", "get_rank", "get_group_size", "get_world_rank_from_group_rank", "GlobalComm", "init", "release", "get_rank", "get_group_size", "get_world_rank_from_group_rank",
"get_group_rank_from_world_rank", "create_group", "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP", "get_group", "get_group_rank_from_world_rank", "create_group", "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP",
"get_local_rank", "get_local_rank_size", "destroy_group" "get_local_rank", "get_local_rank_size", "destroy_group"
] ]
...@@ -21,7 +21,7 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ ...@@ -21,7 +21,7 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", "get_group", __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
"get_local_rank_size", "get_world_rank_from_group_rank", "get_local_rank_size", "get_world_rank_from_group_rank",
"get_group_rank_from_world_rank", "create_group", "destroy_group", "get_group_rank_from_world_rank", "create_group", "destroy_group",
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"] "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
...@@ -30,7 +30,7 @@ DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP ...@@ -30,7 +30,7 @@ DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
DEFAULT_BACKEND = Backend("hccl") DEFAULT_BACKEND = Backend("hccl")
def get_group(group): def _get_group(group):
"""Get the global world group if the group is default world comm group.""" """Get the global world group if the group is default world comm group."""
if group == DEFAULT_WORLD_COMM_GROUP: if group == DEFAULT_WORLD_COMM_GROUP:
return GlobalComm.WORLD_COMM_GROUP return GlobalComm.WORLD_COMM_GROUP
...@@ -100,7 +100,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP): ...@@ -100,7 +100,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid. ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If hccl/nccl is not available or nccl not supports.
""" """
return _get_rank_helper(group=get_group(group), backend=GlobalComm.BACKEND) return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
...@@ -121,7 +121,7 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): ...@@ -121,7 +121,7 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid. ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If hccl/nccl is not available or nccl not supports.
""" """
return _get_local_rank_helper(group=get_group(group), backend=GlobalComm.BACKEND) return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
...@@ -139,7 +139,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): ...@@ -139,7 +139,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid. ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If hccl/nccl is not available or nccl not supports.
""" """
return _get_size_helper(group=get_group(group), backend=GlobalComm.BACKEND) return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP): def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
...@@ -160,7 +160,7 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP): ...@@ -160,7 +160,7 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid. ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports. RuntimeError: If hccl/nccl is not available or nccl not supports.
""" """
return _get_local_size_helper(group=get_group(group), backend=GlobalComm.BACKEND) return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_world_rank_from_group_rank(group, group_rank_id): def get_world_rank_from_group_rank(group, group_rank_id):
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...communication.management import get_rank, get_group_size, GlobalComm, get_group from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
from ...common import dtype as mstype from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register from ..primitive import PrimitiveWithInfer, prim_attr_register
...@@ -88,10 +88,10 @@ class AllReduce(PrimitiveWithInfer): ...@@ -88,10 +88,10 @@ class AllReduce(PrimitiveWithInfer):
raise TypeError("The operation of AllReduce should be str.") raise TypeError("The operation of AllReduce should be str.")
if op == ReduceOp.PROD: if op == ReduceOp.PROD:
raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.") raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.")
if not isinstance(get_group(group), str): if not isinstance(_get_group(group), str):
raise TypeError("The group of AllReduce should be str.") raise TypeError("The group of AllReduce should be str.")
self.op = op self.op = op
self.add_prim_attr('group', get_group(group)) self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0) self.add_prim_attr('fusion', 0)
def vm_impl(self, x): def vm_impl(self, x):
...@@ -149,12 +149,12 @@ class AllGather(PrimitiveWithInfer): ...@@ -149,12 +149,12 @@ class AllGather(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP): def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
validator.check_value_type('group', get_group(group), (str,), self.name) validator.check_value_type('group', _get_group(group), (str,), self.name)
self.rank = get_rank(get_group(group)) self.rank = get_rank(_get_group(group))
self.rank_size = get_group_size(get_group(group)) self.rank_size = get_group_size(_get_group(group))
validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name) validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
self.add_prim_attr('rank_size', self.rank_size) self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', get_group(group)) self.add_prim_attr('group', _get_group(group))
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
x_shape[0] = x_shape[0] * self.rank_size x_shape[0] = x_shape[0] * self.rank_size
...@@ -205,11 +205,11 @@ class ReduceScatter(PrimitiveWithInfer): ...@@ -205,11 +205,11 @@ class ReduceScatter(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP): def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
validator.check_value_type('group', get_group(group), (str,), self.name) validator.check_value_type('group', _get_group(group), (str,), self.name)
self.op = op self.op = op
self.rank_size = get_group_size(get_group(group)) self.rank_size = get_group_size(_get_group(group))
self.add_prim_attr('rank_size', self.rank_size) self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', get_group(group)) self.add_prim_attr('group', _get_group(group))
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
if x_shape[0] % self.rank_size != 0: if x_shape[0] % self.rank_size != 0:
...@@ -268,8 +268,8 @@ class Broadcast(PrimitiveWithInfer): ...@@ -268,8 +268,8 @@ class Broadcast(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP): def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
validator.check_value_type('root_rank', root_rank, (int,), self.name) validator.check_value_type('root_rank', root_rank, (int,), self.name)
validator.check_value_type('group', get_group(group), (str,), self.name) validator.check_value_type('group', _get_group(group), (str,), self.name)
self.add_prim_attr('group', get_group(group)) self.add_prim_attr('group', _get_group(group))
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
return x_shape return x_shape
...@@ -306,11 +306,11 @@ class _AlltoAll(PrimitiveWithInfer): ...@@ -306,11 +306,11 @@ class _AlltoAll(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP): def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP):
"""init AlltoAll""" """init AlltoAll"""
validator.check_value_type('group', get_group(group), (str,), self.name) validator.check_value_type('group', _get_group(group), (str,), self.name)
self.split_count = split_count self.split_count = split_count
self.split_dim = split_dim self.split_dim = split_dim
self.concat_dim = concat_dim self.concat_dim = concat_dim
self.add_prim_attr('group', get_group(group)) self.add_prim_attr('group', _get_group(group))
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册