Unverified commit 350cd82a authored by Roc, committed by GitHub

[kunlun] support async send/recv via group (#50329)

Co-authored-by: zhangxiaoci <zhangxiaoci@baidu.com>
Parent 3862f347
@@ -1251,7 +1251,9 @@ void BindDistributed(py::module *m) {
             py::arg("rank"),
             py::arg("world_size"),
             py::arg("group_id") = 0,
-            py::call_guard<py::gil_scoped_release>());
+            py::call_guard<py::gil_scoped_release>())
+        .def_static("group_start", distributed::ProcessGroupBKCL::GroupStart)
+        .def_static("group_end", distributed::ProcessGroupBKCL::GroupEnd);
 #endif

   py::class_<distributed::ProcessGroup::Task,
...
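These bindings make the BKCL group calls reachable from Python as static methods on the process-group class. A minimal sketch of how they might be invoked, assuming the paddle.framework.core path used by the Python diff below and guarding on paddle.is_compiled_with_xpu() for non-XPU builds:

    import paddle
    from paddle import framework

    # Illustrative only: meaningful on XPU builds where BKCL is available.
    if paddle.is_compiled_with_xpu():
        framework.core.ProcessGroupBKCL.group_start()  # begin batching BKCL p2p calls
        # ... issue async send/recv ops here ...
        framework.core.ProcessGroupBKCL.group_end()    # close the group so the batched calls are issued together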
@@ -350,6 +350,8 @@ def _p2p_helper(
     # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops
     tasks = []
+    if paddle.is_compiled_with_xpu():
+        framework.core.ProcessGroupBKCL.group_start()
     # start to p2p communicate
     if tensor_send_prev is not None:
         if isinstance(tensor_send_prev, tuple):
@@ -479,6 +481,8 @@ def _p2p_helper(
                 )
             else:
                 tasks.append(task)
+    if paddle.is_compiled_with_xpu():
+        framework.core.ProcessGroupBKCL.group_end()
     if not sync_recv:
         if framework.in_dygraph_mode():
...
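Taken together, the change makes _p2p_helper open a BKCL group before issuing its point-to-point ops and close it after all of them are queued, so the async sends/receives can be batched; the returned tasks are still waited on afterwards. A rough standalone sketch of that pattern, where the paddle.distributed.isend/irecv calls, tensor shapes, and the 2-rank peer choice are illustrative assumptions and not taken from this commit:

    import paddle
    import paddle.distributed as dist
    from paddle import framework

    dist.init_parallel_env()
    rank = dist.get_rank()
    peer = 1 - rank  # assumes exactly two ranks
    send_buf = paddle.full([4], float(rank), dtype='float32')
    recv_buf = paddle.zeros([4], dtype='float32')

    tasks = []
    if paddle.is_compiled_with_xpu():
        framework.core.ProcessGroupBKCL.group_start()
    # queue the async point-to-point ops inside the group
    tasks.append(dist.isend(send_buf, dst=peer))
    tasks.append(dist.irecv(recv_buf, src=peer))
    if paddle.is_compiled_with_xpu():
        framework.core.ProcessGroupBKCL.group_end()
    # only wait after the group is closed
    for task in tasks:
        task.wait()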