未验证 提交 6c8a10a2 编写于 作者: S ShenLiang 提交者: GitHub

rm detach (#34644)

上级 6151ccd4
......@@ -257,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
for d in tensor_send_prev:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(
-d.detach(),
+d,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
......@@ -266,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else:
paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
send_partial(
-tensor_send_prev.detach(),
+tensor_send_prev,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
......@@ -277,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if isinstance(tensor_recv_prev, tuple):
for d in tensor_recv_prev:
recv_partial(
-d.detach(),
+d,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(
-d.detach(),
+d,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
recv_partial(
-tensor_recv_prev.detach(),
+tensor_recv_prev,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(
-tensor_recv_prev.detach(),
+tensor_recv_prev,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
......@@ -309,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
for d in tensor_send_next:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(
-d.detach(),
+d,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
......@@ -318,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else:
paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
send_partial(
-tensor_send_next.detach(),
+tensor_send_next,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
......@@ -329,14 +329,14 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if isinstance(tensor_recv_next, tuple):
for d in tensor_recv_next:
recv_partial(
-d.detach(),
+d,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True)
allgather_partial(
-d.detach(),
+d,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
......@@ -344,7 +344,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else:
recv_partial(
-tensor_recv_next.detach(),
+tensor_recv_next,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
......@@ -352,7 +352,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
use_calc_stream=True)
allgather_partial(
-tensor_recv_next.detach(),
+tensor_recv_next,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册