diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index cb3c37975ddf44e28a8ae40043a447e4683489ac..2b49f430df1aa2ebef63609b525fb5a2645b20c8 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -81,9 +81,19 @@ class _Group():
         self.nranks = rank_num
 
 
-_default_group = _Group(
-    int(os.getenv("PADDLE_TRAINER_ID", "0")),
-    int(os.getenv("PADDLE_TRAINERS_NUM", "1")))
+# NOTE(chenweihang): Lazily initialized global group information
+# If we initialize _default_group when import module, it will
+# not update when we use spawn to run multi-process training
+_default_group = None
+
+
+def _get_global_default_group():
+    global _default_group
+    if _default_group is None:
+        _default_group = _Group(
+            int(os.getenv("PADDLE_TRAINER_ID", "0")),
+            int(os.getenv("PADDLE_TRAINERS_NUM", "1")))
+    return _default_group
 
 
 def broadcast(tensor, src, group=0):
@@ -339,6 +349,7 @@ def all_gather(tensor_list, tensor, group=0):
     op_type = 'c_allgather'
     helper = LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+    _default_group = _get_global_default_group()
     if in_dygraph_mode():
         core.ops.c_allgather(tensor, out, 'use_calc_stream', True, 'ring_id',
                              group, 'nranks', _default_group.nranks)
@@ -410,7 +421,7 @@ def scatter(tensor, tensor_list=None, src=0, group=0):
             out = data1.numpy()
     """
     op_type = 'c_scatter'
-    global _default_group
+    _default_group = _get_global_default_group()
     rank = _default_group.rank
     nranks = _default_group.nranks
     if rank != src:
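
For reference, below is a minimal standalone sketch (not part of the patch; the names `_eager_group` and `_get_lazy_group` are illustrative only) of the failure mode the NOTE describes: a module-level group built from environment variables is frozen at import time, whereas a lazy accessor reads the variables on first use, after a spawn-style launcher has exported them into the worker's environment.

```python
# Standalone sketch: import-time state vs. lazy initialization of env-derived config.
import os


class _Group():
    def __init__(self, rank, rank_num):
        self.rank = rank
        self.nranks = rank_num


# Eager: evaluated once at import, before a spawn launcher has exported
# PADDLE_TRAINER_ID / PADDLE_TRAINERS_NUM into this process's environment.
_eager_group = _Group(
    int(os.getenv("PADDLE_TRAINER_ID", "0")),
    int(os.getenv("PADDLE_TRAINERS_NUM", "1")))

_lazy_group = None


def _get_lazy_group():
    # Lazy: the env vars are read on first use, so values set after import
    # (e.g. by a spawn launcher in the worker process) are picked up.
    global _lazy_group
    if _lazy_group is None:
        _lazy_group = _Group(
            int(os.getenv("PADDLE_TRAINER_ID", "0")),
            int(os.getenv("PADDLE_TRAINERS_NUM", "1")))
    return _lazy_group


if __name__ == "__main__":
    # Simulate a launcher configuring the environment after the module was imported.
    os.environ["PADDLE_TRAINERS_NUM"] = "4"
    print(_eager_group.nranks)        # 1 -- stale value captured at import time
    print(_get_lazy_group().nranks)   # 4 -- reflects the launcher's setting
```

The diff applies the same pattern: call sites that previously read the module-level `_default_group` directly now obtain it through `_get_global_default_group()`.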