未验证 提交 fa16c21f 编写于 作者: S ShenLiang 提交者: GitHub

fix bug of inplace (#34665)

上级 8a9dc5dc
...@@ -161,12 +161,15 @@ def send_partial(tensor, ...@@ -161,12 +161,15 @@ def send_partial(tensor,
ring_id = 0 if group is None else group.id ring_id = 0 if group is None else group.id
if _is_valid_send_recv_partial(tensor, nranks): if _is_valid_send_recv_partial(tensor, nranks):
return _C_ops.partial_send(tensor, 'use_calc_stream', use_calc_stream, return _C_ops.partial_send(tensor.detach(), 'use_calc_stream',
'ring_id', ring_id, 'peer', dst, 'num', use_calc_stream, 'ring_id', ring_id, 'peer',
nranks, 'id', rank_id) dst, 'num', nranks, 'id', rank_id)
else: else:
return paddle.distributed.send( return paddle.distributed.send(
tensor, dst=dst, group=group, use_calc_stream=use_calc_stream) tensor.detach(),
dst=dst,
group=group,
use_calc_stream=use_calc_stream)
def recv_partial(tensor, def recv_partial(tensor,
...@@ -180,13 +183,16 @@ def recv_partial(tensor, ...@@ -180,13 +183,16 @@ def recv_partial(tensor,
ring_id = 0 if group is None else group.id ring_id = 0 if group is None else group.id
if _is_valid_send_recv_partial(tensor, nranks): if _is_valid_send_recv_partial(tensor, nranks):
_C_ops.partial_recv(tensor, 'use_calc_stream', use_calc_stream, _C_ops.partial_recv(tensor.detach(), 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', src, 'num', nranks, 'ring_id', ring_id, 'peer', src, 'num', nranks,
'id', rank_id, 'dtype', tensor.dtype, 'out_shape', 'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
tensor.shape) tensor.shape)
else: else:
paddle.distributed.recv( paddle.distributed.recv(
tensor, src=src, group=group, use_calc_stream=use_calc_stream) tensor.detach(),
src=src,
group=group,
use_calc_stream=use_calc_stream)
def allgather_partial(tensor, def allgather_partial(tensor,
...@@ -200,9 +206,9 @@ def allgather_partial(tensor, ...@@ -200,9 +206,9 @@ def allgather_partial(tensor,
return return
ring_id = 0 if group is None else group.id ring_id = 0 if group is None else group.id
return _C_ops.partial_allgather_(tensor, 'use_calc_stream', use_calc_stream, return _C_ops.partial_allgather_(tensor.detach(), 'use_calc_stream',
'ring_id', ring_id, 'nranks', nranks, use_calc_stream, 'ring_id', ring_id,
'rank', rank_id) 'nranks', nranks, 'rank', rank_id)
def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册