diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index cd03e55f25f61b53e126a18567995b02ce9c5ee7..5f481bd0dca41d64b7a5850ee862d33280c3373c 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -403,6 +403,11 @@ def new_group(ranks=None, backend=None):
 
         _group_map_by_name[group_name] = group
         _group_map[gid] = group
+        # TODO(shenliang03): This is a temporary solution to solve the problem of
+        # hang caused by tcp
+        tmp = paddle.to_tensor([1], dtype="int32")
+        paddle.distributed.all_reduce(tmp, group=group, use_calc_stream=True)
+        paddle.distributed.wait(tmp)
         return group
 
     if not backend:
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index b6698a200e9454cbf948cf8240bac2c22c5ed5a4..de36f8503a651f96753ef7acbd26ba05458192db 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -18,6 +18,7 @@ from ...utils.log_util import logger
 import numpy as np
 from paddle import _C_ops
 import paddle.fluid.core as core
+from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode
 
 _hcg = None
 _use_cache = False
@@ -148,9 +149,15 @@ _send_recv_meta = SendRecvMeta()
 
 
 def _is_valid_send_recv_partial(tensor, mp_degree):
-    tensor_numel = np.prod(tensor.shape)
-    assert tensor_numel != 0, "can't send/recv zero element"
-    return mp_degree > 1 and tensor_numel % mp_degree == 0
+
+    if _in_legacy_dygraph():
+        tensor_numel = np.prod(tensor.shape)
+        assert tensor_numel != 0, "can't send/recv zero element"
+        return mp_degree > 1 and tensor_numel % mp_degree == 0
+    elif in_dygraph_mode():
+        # TODO(shenliang03) support mp+pp optimizer in future.
+        # (partial_send/partial_recv/partial_allgather_)
+        return False
 
 
 def send_partial(tensor,
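A minimal sketch (not part of the patch) of the warm-up pattern the collective.py hunk adds: a newly created group's communicator is set up lazily over TCP, so ranks that reach their first real collective at different times can hang; the patch therefore issues a throwaway all_reduce on the calc stream right after group creation and waits on it. The ranks list below is hypothetical, and the script assumes a multi-rank launch (e.g. via paddle.distributed.launch).

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# With the patch applied, new_group already performs this warm-up internally.
group = dist.new_group(ranks=[0, 1])

# The same pattern written out explicitly: a dummy all_reduce forces every
# rank in the group to establish its communicator eagerly, and wait() blocks
# until the calc stream is done with the tensor.
tmp = paddle.to_tensor([1], dtype="int32")
dist.all_reduce(tmp, group=group, use_calc_stream=True)
dist.wait(tmp)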
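And a standalone sketch of the validity check that the p2p_communication.py hunk now gates on the dygraph mode; the function name and the legacy flag below are hypothetical, used only to make the branch testable outside Paddle's pipeline-parallel internals. In legacy dygraph a tensor may be sent/received partially only when its element count is nonzero and divisible by mp_degree; in eager mode the patch returns False, falling back to full-tensor send/recv until partial_send/partial_recv/partial_allgather_ are supported there.

import numpy as np

def _is_valid_send_recv_partial_demo(shape, mp_degree, legacy=True):
    if legacy:
        # Legacy dygraph: a partial transfer is valid only for an even split.
        tensor_numel = int(np.prod(shape))
        assert tensor_numel != 0, "can't send/recv zero element"
        return mp_degree > 1 and tensor_numel % mp_degree == 0
    # Eager mode: partial ops are not supported yet, so use the full tensor.
    return False

print(_is_valid_send_recv_partial_demo([4, 8], 4))  # True: 32 % 4 == 0
print(_is_valid_send_recv_partial_demo([3, 5], 4))  # False: 15 % 4 != 0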