diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index ca438326386f0dc7a411310d945961eb94ea7926..c9566b188a54ee0494f73164fb66fc05dd42ebfb 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -199,7 +199,8 @@ def send_partial(tensor, dst_rank = _hcg._get_p2p_next_rank( ) if dst == 1 else _hcg._get_p2p_prev_rank() if _in_legacy_dygraph(): - send_op = paddle.distributed.send + send_op = lambda x, dst, group: \ + paddle.distributed.send(x, dst, group, use_calc_stream) elif in_dygraph_mode(): send_op = paddle.distributed.isend return send_op(tensor.detach(), dst=dst_rank, group=group)