Unverified commit 6c8a10a2, authored by ShenLiang, committed by GitHub

rm detach (#34644)

Parent 6151ccd4
@@ -257,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
            for d in tensor_send_prev:
                paddle.distributed.wait(d, use_calc_stream=True)
                send_partial(
-                   d.detach(),
+                   d,
                    dst=0,
                    nranks=mp_degree,
                    rank_id=mp_rank,
@@ -266,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
        else:
            paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
            send_partial(
-               tensor_send_prev.detach(),
+               tensor_send_prev,
                dst=0,
                nranks=mp_degree,
                rank_id=mp_rank,
@@ -277,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
        if isinstance(tensor_recv_prev, tuple):
            for d in tensor_recv_prev:
                recv_partial(
-                   d.detach(),
+                   d,
                    src=0,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.recv_prev_group,
                    use_calc_stream=True)
                allgather_partial(
-                   d.detach(),
+                   d,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=mp_group,
                    use_calc_stream=True)
        else:
            recv_partial(
-               tensor_recv_prev.detach(),
+               tensor_recv_prev,
                src=0,
                nranks=mp_degree,
                rank_id=mp_rank,
                group=_hcg.recv_prev_group,
                use_calc_stream=True)
            allgather_partial(
-               tensor_recv_prev.detach(),
+               tensor_recv_prev,
                nranks=mp_degree,
                rank_id=mp_rank,
                group=mp_group,
@@ -309,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
            for d in tensor_send_next:
                paddle.distributed.wait(d, use_calc_stream=True)
                send_partial(
-                   d.detach(),
+                   d,
                    dst=1,
                    nranks=mp_degree,
                    rank_id=mp_rank,
@@ -318,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
        else:
            paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
            send_partial(
-               tensor_send_next.detach(),
+               tensor_send_next,
                dst=1,
                nranks=mp_degree,
                rank_id=mp_rank,
@@ -329,14 +329,14 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
        if isinstance(tensor_recv_next, tuple):
            for d in tensor_recv_next:
                recv_partial(
-                   d.detach(),
+                   d,
                    src=1,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.recv_next_group,
                    use_calc_stream=True)
                allgather_partial(
-                   d.detach(),
+                   d,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=mp_group,
@@ -344,7 +344,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
        else:
            recv_partial(
-               tensor_recv_next.detach(),
+               tensor_recv_next,
                src=1,
                nranks=mp_degree,
                rank_id=mp_rank,
                group=_hcg.recv_next_group,
@@ -352,7 +352,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                use_calc_stream=True)
            allgather_partial(
-               tensor_recv_next.detach(),
+               tensor_recv_next,
                nranks=mp_degree,
                rank_id=mp_rank,
                group=mp_group,
                use_calc_stream=True)
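A note on the change (not part of the commit itself): in Paddle's dynamic graph mode, Tensor.detach() returns a new Tensor that is cut out of the autograd graph but shares storage with the original, so send_partial/recv_partial/allgather_partial touch the same memory whether they are handed d or d.detach(). Dropping the call therefore does not change what is communicated; it only avoids building a throwaway detached view on every pipeline step. The sketch below is a minimal, self-contained illustration of that storage-sharing behavior using only the public paddle API; it is a hypothetical example, not code from this commit.

import paddle

# Hypothetical standalone example, not taken from the diff above.
x = paddle.ones([2, 2])
x.stop_gradient = False

# detach() leaves the autograd graph but shares storage with x (no copy).
y = x.detach()
y[:] = 5.0  # write through the detached view

# The write is visible in x, which is why the communication ops behave
# the same whether they receive a tensor or its detached view.
print(x)  # all elements are now 5.0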