diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 5e2f4ba721931299a8f0a01ef709760be9f9f6fb..c508c88015cfda3b3f4071a2dc898ca062b19f51 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -161,12 +161,15 @@ def send_partial(tensor, ring_id = 0 if group is None else group.id if _is_valid_send_recv_partial(tensor, nranks): - return _C_ops.partial_send(tensor, 'use_calc_stream', use_calc_stream, - 'ring_id', ring_id, 'peer', dst, 'num', - nranks, 'id', rank_id) + return _C_ops.partial_send(tensor.detach(), 'use_calc_stream', + use_calc_stream, 'ring_id', ring_id, 'peer', + dst, 'num', nranks, 'id', rank_id) else: return paddle.distributed.send( - tensor, dst=dst, group=group, use_calc_stream=use_calc_stream) + tensor.detach(), + dst=dst, + group=group, + use_calc_stream=use_calc_stream) def recv_partial(tensor, @@ -180,13 +183,16 @@ def recv_partial(tensor, ring_id = 0 if group is None else group.id if _is_valid_send_recv_partial(tensor, nranks): - _C_ops.partial_recv(tensor, 'use_calc_stream', use_calc_stream, + _C_ops.partial_recv(tensor.detach(), 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape', tensor.shape) else: paddle.distributed.recv( - tensor, src=src, group=group, use_calc_stream=use_calc_stream) + tensor.detach(), + src=src, + group=group, + use_calc_stream=use_calc_stream) def allgather_partial(tensor, @@ -200,9 +206,9 @@ def allgather_partial(tensor, return ring_id = 0 if group is None else group.id - return _C_ops.partial_allgather_(tensor, 'use_calc_stream', use_calc_stream, - 'ring_id', ring_id, 'nranks', nranks, - 'rank', rank_id) + return _C_ops.partial_allgather_(tensor.detach(), 'use_calc_stream', + use_calc_stream, 'ring_id', ring_id, + 'nranks', nranks, 'rank', rank_id) def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):