diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index b92b2a3c15decd66206e7ec78b749cfbda64241d..65018e7e2e6e752145f1d35ff4787532f5861496 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -378,6 +378,7 @@ def new_group(ranks=None, backend=None):
         _group_map_by_name[group_name] = group
         _group_map[gid] = group
+        paddle.distributed.barrier(group=group)
         return group
 
     if not backend:
diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py
index f0365cab8c896689ba2f49020878e3236d3381a9..13a027db37eddc384e2a395d5d0f0ef9a51e9275 100644
--- a/python/paddle/distributed/parallel.py
+++ b/python/paddle/distributed/parallel.py
@@ -20,6 +20,7 @@ from multiprocessing import Manager  # noqa: F401
 import time
 import sys
 
+import paddle
 from paddle import compat as cpt
 
 # deprecated module import
@@ -253,6 +254,7 @@ def init_parallel_env():
         _set_group_map_by_name(_default_group_name, group)
         _set_group_map(0, group)
         parallel_helper._set_parallel_ctx(True)
+        paddle.distributed.barrier(group=group)
         return group
 
     node_num = set([i.split(":")[0] for i in parallel_env.trainer_endpoints])
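For context, the sketch below illustrates the synchronization this diff introduces: both `init_parallel_env()` and `new_group()` now end with a `paddle.distributed.barrier(group=group)`, so no rank can return from group creation and issue a collective on a group that a slower rank has not finished constructing. The script is a hypothetical usage example, not part of the PR; the file name `demo.py`, the launch command, and the two-rank setup are assumptions.

```python
# Minimal sketch of the race the added barrier guards against.
# Assumed launch (not from the diff):
#   python -m paddle.distributed.launch --nproc_per_node=2 demo.py
import paddle
import paddle.distributed as dist

# After this change, init_parallel_env() ends with a barrier on the
# default group, so every rank returns only once all ranks are up.
dist.init_parallel_env()

# new_group() likewise blocks until every participating rank has built
# the group; previously a fast rank could reach the all_reduce below
# while a slow rank was still constructing its communicator.
group = dist.new_group(ranks=[0, 1])

# Each rank contributes its own rank id; the default reduce op is SUM.
x = paddle.to_tensor([float(dist.get_rank())])
dist.all_reduce(x, group=group)  # safe: all ranks own the group by now
print(dist.get_rank(), x.numpy())  # both ranks print [1.]
```

Putting the barrier inside group creation, rather than expecting every caller to synchronize manually, makes "the group is usable as soon as the call returns" part of the API contract.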