Unverified commit 6d9bbee3, authored by Roc, committed by GitHub

[KUNLUN] fix pp send/recv on xpu (#53427)

Make the communication synchronize at the first recv operator. If all send and recv operators are wrapped in a single group start/end pair, the received tensor will not be complete when it is consumed.
Parent 207e0f33
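The change replaces the blanket `group_start()`/`group_end()` wrap around the whole send/recv batch with a pair of guarded helpers, so the BKCL group can be closed as soon as the first synchronous recv has to hand its tensor to `allgather_partial`, while the unconditional call at the end of `_p2p_helper` degrades to a no-op. Below is a minimal standalone sketch of that pairing logic; `FakeBKCL` and `comm_group_start`/`comm_group_end` are stand-ins for `framework.core.ProcessGroupBKCL` and the new helpers so the sketch runs without Paddle or an XPU.

```python
# Sketch of the guarded group start/end pattern introduced by this commit.
# FakeBKCL is a stand-in for framework.core.ProcessGroupBKCL, not a Paddle API.

class FakeBKCL:
    calls = []

    @classmethod
    def group_start(cls):
        cls.calls.append("group_start")

    @classmethod
    def group_end(cls):
        cls.calls.append("group_end")


_group_started = False


def comm_group_start():
    global _group_started
    assert not _group_started, "communication group already open"
    FakeBKCL.group_start()
    _group_started = True


def comm_group_end():
    # Safe to call more than once: only the first call after a start
    # actually closes the group, later calls are no-ops.
    global _group_started
    if _group_started:
        FakeBKCL.group_end()
        _group_started = False


comm_group_start()      # open the group, then queue isend/irecv ops
comm_group_end()        # closed at the first synchronous recv
comm_group_end()        # trailing call at the end of the helper: no-op
print(FakeBKCL.calls)   # ['group_start', 'group_end']
```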
@@ -24,6 +24,26 @@ _hcg = None
 _use_cache = False
 _enable_partial_send_recv = True
 
+_xpu_comm_group_started = False
+
+
+def _xpu_comm_group_start():
+    if not paddle.is_compiled_with_xpu():
+        return
+    global _xpu_comm_group_started
+    assert not _xpu_comm_group_started
+    framework.core.ProcessGroupBKCL.group_start()
+    _xpu_comm_group_started = True
+
+
+def _xpu_comm_group_end():
+    if not paddle.is_compiled_with_xpu():
+        return
+    global _xpu_comm_group_started
+    if _xpu_comm_group_started:
+        framework.core.ProcessGroupBKCL.group_end()
+        _xpu_comm_group_started = False
+
+
 def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
     global _hcg, _use_cache, _enable_partial_send_recv
@@ -350,9 +370,8 @@ def _p2p_helper(
     # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops
     tasks = []
-    if paddle.is_compiled_with_xpu():
-        framework.core.ProcessGroupBKCL.group_start()
     # start to p2p communicate
+    _xpu_comm_group_start()
     if tensor_send_prev is not None:
         if isinstance(tensor_send_prev, tuple):
             for d in tensor_send_prev:
@@ -388,6 +407,7 @@ def _p2p_helper(
                     use_calc_stream=sync_recv,
                 )
                 if sync_recv:
+                    _xpu_comm_group_end()
                     allgather_partial(
                         d,
                         nranks=mp_degree,
@@ -406,7 +426,9 @@ def _p2p_helper(
                 group=_hcg.recv_prev_group,
                 use_calc_stream=sync_recv,
             )
             if sync_recv:
+                _xpu_comm_group_end()
                 allgather_partial(
                     tensor_recv_prev,
                     nranks=mp_degree,
@@ -451,7 +473,9 @@ def _p2p_helper(
                     group=_hcg.recv_next_group,
                     use_calc_stream=sync_recv,
                 )
                 if sync_recv:
+                    _xpu_comm_group_end()
                     allgather_partial(
                         d,
                         nranks=mp_degree,
@@ -472,6 +496,7 @@ def _p2p_helper(
                 use_calc_stream=sync_recv,
             )
             if sync_recv:
+                _xpu_comm_group_end()
                 allgather_partial(
                     tensor_recv_next,
                     nranks=mp_degree,
@@ -481,9 +506,7 @@
                 )
             else:
                 tasks.append(task)
-    if paddle.is_compiled_with_xpu():
-        framework.core.ProcessGroupBKCL.group_end()
+    _xpu_comm_group_end()
     if not sync_recv:
         if framework.in_dygraph_mode():
             # wait irecv tasks in eager dygraph mode with new comm library
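For orientation, here is a rough, self-contained sketch of where the patched `_p2p_helper` opens and closes the group on XPU. Everything in it (`p2p_step`, the stub `group_start`/`group_end`, the `print` calls standing in for `send_partial`/`recv_partial`/`allgather_partial`) is a placeholder, not a Paddle API; only the placement of the start/end calls mirrors the patch.

```python
# Standalone sketch of the control flow after this patch; all names here are
# placeholders, only the position of the start/end calls mirrors the real code.

_started = False


def group_start():
    global _started
    assert not _started
    print("BKCL group_start")
    _started = True


def group_end():
    global _started
    if _started:
        print("BKCL group_end")
        _started = False


def p2p_step(send_names, recv_names, sync_recv=True):
    group_start()                       # _xpu_comm_group_start()
    for name in send_names:
        print("queue isend:", name)     # send_partial(...)
    for name in recv_names:
        print("queue irecv:", name)     # recv_partial(...)
        if sync_recv:
            group_end()                 # close the group before the received
            print("consume:", name)     # tensor is used (allgather_partial)
    group_end()                         # trailing call: no-op if already closed


p2p_step(["grads_to_prev_stage"], ["acts_from_prev_stage"])
```

As the commit message notes, closing the group right before the first synchronous recv result is consumed is what keeps the received tensor complete; closing it only once at the very end, as before the patch, would let the consumer read a tensor whose recv had not yet been flushed.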