From 1e9299aa3ebfa53f101952ddac45796d756b8e55 Mon Sep 17 00:00:00 2001
From: ShenLiang
Date: Fri, 4 Jun 2021 14:38:36 +0800
Subject: [PATCH] Fix hang of hybrid parallel in new_group (#33141)

* fix hang of hybrid parallel

* fix new_group for hang problem
---
 python/paddle/distributed/collective.py | 54 ++++++++++++++-----------
 1 file changed, 30 insertions(+), 24 deletions(-)

diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 5775a734c8..be30604098 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -239,31 +239,37 @@ def new_group(ranks=None, backend=None):
     if global_rank not in ranks:
         gp = Group(-1, -1, ring_id, ranks)
         _group_map[ring_id] = gp
-        return gp
-
-    ranks = sorted(ranks)
-    group_rank = ranks.index(global_rank)
-    group_size = len(ranks)
-    gp = Group(group_rank, group_size, ring_id, ranks)
-    _group_map[ring_id] = gp
-
-    if group_size < 2:
-        return gp
-
-    strategy = core.ParallelStrategy()
-    strategy.nranks = group_size
-    strategy.local_rank = group_rank
-    strategy.trainer_endpoints = [genv.trainer_endpoints[i] for i in ranks]
-    strategy.current_endpoint = genv.current_endpoint
-    strategy.nrings = 1
-
-    if core.is_compiled_with_cuda():
-        place = core.CUDAPlace(genv.device_id)
-        core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id)
     else:
-        assert False, ("no cuda device found")
-    # need to barrier to construct group
-    barrier(gp)
+        ranks = sorted(ranks)
+        group_rank = ranks.index(global_rank)
+        group_size = len(ranks)
+        gp = Group(group_rank, group_size, ring_id, ranks)
+        _group_map[ring_id] = gp
+
+        if group_size >= 2:
+            strategy = core.ParallelStrategy()
+            strategy.nranks = group_size
+            strategy.local_rank = group_rank
+            strategy.trainer_endpoints = [
+                genv.trainer_endpoints[i] for i in ranks
+            ]
+            strategy.current_endpoint = genv.current_endpoint
+            strategy.nrings = 1
+
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(genv.device_id)
+                core.NCCLParallelContext(strategy,
+                                         place).init_with_ring_id(ring_id)
+            else:
+                assert False, ("no cuda device found")
+        else:
+            return gp
+
+    # TODO(shenliang03): This is a temporary solution to solve the problem of
+    # hang caused by cross-creation of new_group
+    tmp = fill_constant([0], dtype="int32", value="1")
+    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
+    paddle.distributed.wait(tmp)
     return gp
 
 
-- 
GitLab
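
For context on the TODO above: in the old code, ranks outside the new group returned immediately while member ranks initialized NCCL and then hit barrier(gp); the patch instead has every rank, member or not, finish new_group() through the same global fill_constant/all_reduce/wait, so group creation stays in lockstep across ranks and the cross-creation hang is avoided. Below is a minimal usage sketch of that calling pattern, not part of the patch; the 4-GPU layout, the [0, 1]/[2, 3] split, and the script name are illustrative assumptions.

    # sketch_new_group.py -- hypothetical example, not from the patch.
    # Launch on 4 GPUs, e.g.:
    #   python -m paddle.distributed.launch --gpus 0,1,2,3 sketch_new_group.py
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    # Every rank calls new_group() for every subgroup, including the ones it
    # does not belong to; with this patch each call ends in a global
    # all_reduce/wait that keeps the ranks in step before the next group is
    # created.
    group_a = dist.new_group(ranks=[0, 1])
    group_b = dist.new_group(ranks=[2, 3])

    # Each rank then communicates only inside its own subgroup.
    x = paddle.ones([2], dtype="int32")
    if dist.get_rank() in [0, 1]:
        dist.all_reduce(x, group=group_a)
    else:
        dist.all_reduce(x, group=group_b)
    print("rank", dist.get_rank(), "->", x.numpy())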