Unverified commit bff4179c, authored by Chen Weihang, committed by GitHub

lazily init global group in collective (#28780)

Parent a22ea652
@@ -81,9 +81,19 @@ class _Group():
         self.nranks = rank_num
 
 
-_default_group = _Group(
-    int(os.getenv("PADDLE_TRAINER_ID", "0")),
-    int(os.getenv("PADDLE_TRAINERS_NUM", "1")))
+# NOTE(chenweihang): Lazily initialized global group information
+# If we initialize _default_group when import module, it will
+# not update when we use spawn to run multi-process training
+_default_group = None
+
+
+def _get_global_default_group():
+    global _default_group
+    if _default_group is None:
+        _default_group = _Group(
+            int(os.getenv("PADDLE_TRAINER_ID", "0")),
+            int(os.getenv("PADDLE_TRAINERS_NUM", "1")))
+    return _default_group
 
 
 def broadcast(tensor, src, group=0):
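
The NOTE in this hunk carries the rationale for the whole commit: a module-level global built from environment variables is frozen at import time, so children started via spawn never see the per-process values the launcher sets. Below is a minimal, self-contained sketch of that failure mode and the lazy fix; it is not Paddle code, and the names EAGER_TRAINER_ID and get_trainer_id are invented for illustration.

import os

# Eager: the environment is read once, at import time. A launcher that sets
# PADDLE_TRAINER_ID after this module is imported (as spawn-style launchers
# do for each child process) never affects this value.
EAGER_TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", "0"))

# Lazy: the environment is read on first use, after the launcher has had a
# chance to configure it.
_lazy_trainer_id = None


def get_trainer_id():
    global _lazy_trainer_id
    if _lazy_trainer_id is None:
        _lazy_trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
    return _lazy_trainer_id


if __name__ == "__main__":
    os.environ["PADDLE_TRAINER_ID"] = "3"  # what a launcher would do
    print(EAGER_TRAINER_ID)  # 0 -- frozen at import (if unset beforehand)
    print(get_trainer_id())  # 3 -- picked up lazily, at first call

The same single-check-then-cache shape appears in _get_global_default_group() above: the first caller pays the construction cost, every later caller gets the cached _Group.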
@@ -339,6 +349,7 @@ def all_gather(tensor_list, tensor, group=0):
     op_type = 'c_allgather'
     helper = LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+    _default_group = _get_global_default_group()
     if in_dygraph_mode():
         core.ops.c_allgather(tensor, out, 'use_calc_stream', True, 'ring_id',
                              group, 'nranks', _default_group.nranks)
@@ -410,7 +421,7 @@ def scatter(tensor, tensor_list=None, src=0, group=0):
         out = data1.numpy()
     """
     op_type = 'c_scatter'
-    global _default_group
+    _default_group = _get_global_default_group()
     rank = _default_group.rank
     nranks = _default_group.nranks
     if rank != src:
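
The second and third hunks apply the matching call-site discipline: each collective op now obtains the group through the getter instead of reading the module global directly, which starts as None after this change. The sketch below restates that pattern in a self-contained form; the helper scatter_like_op is hypothetical, standing in for ops like scatter and all_gather.

import os


class _Group:
    """Stand-in for the _Group class in the patched module."""

    def __init__(self, rank, rank_num):
        self.rank = rank
        self.nranks = rank_num


_default_group = None


def _get_global_default_group():
    global _default_group
    if _default_group is None:
        _default_group = _Group(
            int(os.getenv("PADDLE_TRAINER_ID", "0")),
            int(os.getenv("PADDLE_TRAINERS_NUM", "1")))
    return _default_group


def scatter_like_op(src=0):
    # Previously: `global _default_group`, relying on an eagerly built
    # object. Now the first op to run also initializes the group, using
    # whatever environment the launcher set up for this process.
    group = _get_global_default_group()
    return group.rank == src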