提交 ba3a1f4f 编写于 作者: L lichenever

change get_group to internal interface

上级 96d39886
......@@ -17,12 +17,12 @@ Collective communication interface.
"""
from .management import GlobalComm, init, release, get_rank, get_group_size, get_world_rank_from_group_rank, \
get_group_rank_from_world_rank, create_group, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, get_group, \
get_group_rank_from_world_rank, create_group, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
get_local_rank, get_local_rank_size, destroy_group
__all__ = [
"GlobalComm", "init", "release", "get_rank", "get_group_size", "get_world_rank_from_group_rank",
"get_group_rank_from_world_rank", "create_group", "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP", "get_group",
"get_group_rank_from_world_rank", "create_group", "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP",
"get_local_rank", "get_local_rank_size", "destroy_group"
]
......@@ -21,7 +21,7 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", "get_group",
__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
"get_local_rank_size", "get_world_rank_from_group_rank",
"get_group_rank_from_world_rank", "create_group", "destroy_group",
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
......@@ -30,7 +30,7 @@ DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
DEFAULT_BACKEND = Backend("hccl")
def get_group(group):
def _get_group(group):
"""Get the global world group if the group is default world comm group."""
if group == DEFAULT_WORLD_COMM_GROUP:
return GlobalComm.WORLD_COMM_GROUP
......@@ -100,7 +100,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports.
"""
return _get_rank_helper(group=get_group(group), backend=GlobalComm.BACKEND)
return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
......@@ -121,7 +121,7 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports.
"""
return _get_local_rank_helper(group=get_group(group), backend=GlobalComm.BACKEND)
return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
......@@ -139,7 +139,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports.
"""
return _get_size_helper(group=get_group(group), backend=GlobalComm.BACKEND)
return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
......@@ -160,7 +160,7 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid.
RuntimeError: If hccl/nccl is not available or nccl not supports.
"""
return _get_local_size_helper(group=get_group(group), backend=GlobalComm.BACKEND)
return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_world_rank_from_group_rank(group, group_rank_id):
......
......@@ -17,7 +17,7 @@
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...communication.management import get_rank, get_group_size, GlobalComm, get_group
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
......@@ -88,10 +88,10 @@ class AllReduce(PrimitiveWithInfer):
raise TypeError("The operation of AllReduce should be str.")
if op == ReduceOp.PROD:
raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.")
if not isinstance(get_group(group), str):
if not isinstance(_get_group(group), str):
raise TypeError("The group of AllReduce should be str.")
self.op = op
self.add_prim_attr('group', get_group(group))
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
def vm_impl(self, x):
......@@ -149,12 +149,12 @@ class AllGather(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
validator.check_value_type('group', get_group(group), (str,), self.name)
self.rank = get_rank(get_group(group))
self.rank_size = get_group_size(get_group(group))
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.rank = get_rank(_get_group(group))
self.rank_size = get_group_size(_get_group(group))
validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', get_group(group))
self.add_prim_attr('group', _get_group(group))
def infer_shape(self, x_shape):
x_shape[0] = x_shape[0] * self.rank_size
......@@ -205,11 +205,11 @@ class ReduceScatter(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
validator.check_value_type('group', get_group(group), (str,), self.name)
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.op = op
self.rank_size = get_group_size(get_group(group))
self.rank_size = get_group_size(_get_group(group))
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', get_group(group))
self.add_prim_attr('group', _get_group(group))
def infer_shape(self, x_shape):
if x_shape[0] % self.rank_size != 0:
......@@ -268,8 +268,8 @@ class Broadcast(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
validator.check_value_type('root_rank', root_rank, (int,), self.name)
validator.check_value_type('group', get_group(group), (str,), self.name)
self.add_prim_attr('group', get_group(group))
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.add_prim_attr('group', _get_group(group))
def infer_shape(self, x_shape):
return x_shape
......@@ -306,11 +306,11 @@ class _AlltoAll(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP):
"""init AlltoAll"""
validator.check_value_type('group', get_group(group), (str,), self.name)
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.split_count = split_count
self.split_dim = split_dim
self.concat_dim = concat_dim
self.add_prim_attr('group', get_group(group))
self.add_prim_attr('group', _get_group(group))
def infer_shape(self, x_shape):
x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册