未验证 提交 fa16c21f 编写于 作者: S ShenLiang 提交者: GitHub

fix bug of inplace (#34665)

上级 8a9dc5dc
......@@ -161,12 +161,15 @@ def send_partial(tensor,
ring_id = 0 if group is None else group.id
if _is_valid_send_recv_partial(tensor, nranks):
return _C_ops.partial_send(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', dst, 'num',
nranks, 'id', rank_id)
return _C_ops.partial_send(tensor.detach(), 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id, 'peer',
dst, 'num', nranks, 'id', rank_id)
else:
return paddle.distributed.send(
tensor, dst=dst, group=group, use_calc_stream=use_calc_stream)
tensor.detach(),
dst=dst,
group=group,
use_calc_stream=use_calc_stream)
def recv_partial(tensor,
......@@ -180,13 +183,16 @@ def recv_partial(tensor,
ring_id = 0 if group is None else group.id
if _is_valid_send_recv_partial(tensor, nranks):
_C_ops.partial_recv(tensor, 'use_calc_stream', use_calc_stream,
_C_ops.partial_recv(tensor.detach(), 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', src, 'num', nranks,
'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
tensor.shape)
else:
paddle.distributed.recv(
tensor, src=src, group=group, use_calc_stream=use_calc_stream)
tensor.detach(),
src=src,
group=group,
use_calc_stream=use_calc_stream)
def allgather_partial(tensor,
......@@ -200,9 +206,9 @@ def allgather_partial(tensor,
return
ring_id = 0 if group is None else group.id
return _C_ops.partial_allgather_(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks,
'rank', rank_id)
return _C_ops.partial_allgather_(tensor.detach(), 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id,
'nranks', nranks, 'rank', rank_id)
def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册