diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 5775a734c870690613d474f279ba2e20b720b759..be30604098fe9c22a47c4d0c57306baa907b33b3 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -239,31 +239,37 @@ def new_group(ranks=None, backend=None):
     if global_rank not in ranks:
         gp = Group(-1, -1, ring_id, ranks)
         _group_map[ring_id] = gp
-        return gp
-
-    ranks = sorted(ranks)
-    group_rank = ranks.index(global_rank)
-    group_size = len(ranks)
-    gp = Group(group_rank, group_size, ring_id, ranks)
-    _group_map[ring_id] = gp
-
-    if group_size < 2:
-        return gp
-
-    strategy = core.ParallelStrategy()
-    strategy.nranks = group_size
-    strategy.local_rank = group_rank
-    strategy.trainer_endpoints = [genv.trainer_endpoints[i] for i in ranks]
-    strategy.current_endpoint = genv.current_endpoint
-    strategy.nrings = 1
-
-    if core.is_compiled_with_cuda():
-        place = core.CUDAPlace(genv.device_id)
-        core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id)
     else:
-        assert False, ("no cuda device found")
-    # need to barrier to construct group
-    barrier(gp)
+        ranks = sorted(ranks)
+        group_rank = ranks.index(global_rank)
+        group_size = len(ranks)
+        gp = Group(group_rank, group_size, ring_id, ranks)
+        _group_map[ring_id] = gp
+
+        if group_size >= 2:
+            strategy = core.ParallelStrategy()
+            strategy.nranks = group_size
+            strategy.local_rank = group_rank
+            strategy.trainer_endpoints = [
+                genv.trainer_endpoints[i] for i in ranks
+            ]
+            strategy.current_endpoint = genv.current_endpoint
+            strategy.nrings = 1
+
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(genv.device_id)
+                core.NCCLParallelContext(strategy,
+                                         place).init_with_ring_id(ring_id)
+            else:
+                assert False, ("no cuda device found")
+        else:
+            return gp
+
+    # TODO(shenliang03): This is a temporary solution to solve the problem of
+    # hang caused by cross-creation of new_group
+    tmp = fill_constant([0], dtype="int32", value="1")
+    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
+    paddle.distributed.wait(tmp)
     return gp
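
For reviewers, a minimal usage sketch of the reworked `new_group` path. The 4-rank launch, the even/odd rank split, and passing `group=` to `all_reduce` are illustrative assumptions, not part of this patch.

```python
# Illustrative sketch only (assumes a 4-GPU job started with
# `python -m paddle.distributed.launch`); not part of this patch.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# Every rank calls new_group for every group, including ranks that are not
# members. With this change, non-member ranks no longer return early before
# the trailing all_reduce/wait, so that implicit barrier is reached by all
# ranks and cross-ordered group creation no longer hangs.
evens = dist.new_group(ranks=[0, 2])
odds = dist.new_group(ranks=[1, 3])

# Collectives are then issued only by ranks that belong to the group.
x = paddle.ones([2, 2])
if dist.get_rank() in (0, 2):
    dist.all_reduce(x, group=evens)
```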