From fa16c21f90660c16ec2b4f4b7f6035d1c14c782a Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Fri, 6 Aug 2021 17:16:41 +0800 Subject: [PATCH] fix bug of inplace (#34665) --- .../pp_utils/p2p_communication.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 5e2f4ba7219..c508c88015c 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -161,12 +161,15 @@ def send_partial(tensor, ring_id = 0 if group is None else group.id if _is_valid_send_recv_partial(tensor, nranks): - return _C_ops.partial_send(tensor, 'use_calc_stream', use_calc_stream, - 'ring_id', ring_id, 'peer', dst, 'num', - nranks, 'id', rank_id) + return _C_ops.partial_send(tensor.detach(), 'use_calc_stream', + use_calc_stream, 'ring_id', ring_id, 'peer', + dst, 'num', nranks, 'id', rank_id) else: return paddle.distributed.send( - tensor, dst=dst, group=group, use_calc_stream=use_calc_stream) + tensor.detach(), + dst=dst, + group=group, + use_calc_stream=use_calc_stream) def recv_partial(tensor, @@ -180,13 +183,16 @@ def recv_partial(tensor, ring_id = 0 if group is None else group.id if _is_valid_send_recv_partial(tensor, nranks): - _C_ops.partial_recv(tensor, 'use_calc_stream', use_calc_stream, + _C_ops.partial_recv(tensor.detach(), 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape', tensor.shape) else: paddle.distributed.recv( - tensor, src=src, group=group, use_calc_stream=use_calc_stream) + tensor.detach(), + src=src, + group=group, + use_calc_stream=use_calc_stream) def allgather_partial(tensor, @@ -200,9 +206,9 @@ def allgather_partial(tensor, return ring_id = 0 if group is None else group.id - return _C_ops.partial_allgather_(tensor, 'use_calc_stream', use_calc_stream, - 'ring_id', ring_id, 'nranks', nranks, - 'rank', rank_id) + return _C_ops.partial_allgather_(tensor.detach(), 'use_calc_stream', + use_calc_stream, 'ring_id', ring_id, + 'nranks', nranks, 'rank', rank_id) def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): -- GitLab