Unverified commit 4a97ba5d, authored by Roc, committed by GitHub

[KUNLUN]Revert "revert p2p communication for xpu (#53496)" (#53633)

* Revert "revert p2p communication for xpu (#53496)"

This reverts commit eda0c588.

* update
Parent 80757527
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
@@ -27,10 +26,30 @@ _hcg = None
_use_cache = False
_enable_partial_send_recv = True
_xpu_comm_group_started = False
_sync_send = os.environ.get("PADDLE_P2P_SYNC_SEND", "0")
_sync_send = _sync_send.lower() in ['1', 'true']
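# Note (descriptive comment, not in the committed file): the flag is parsed
# once at import time, so PADDLE_P2P_SYNC_SEND must be set to "1" or "true"
# (case-insensitive) before the process starts to opt into the synchronous
# send ordering used below; any other value keeps the asynchronous path.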

def _xpu_comm_group_start():
    if not paddle.is_compiled_with_xpu():
        return
    global _xpu_comm_group_started
    assert not _xpu_comm_group_started
    framework.core.ProcessGroupBKCL.group_start()
    _xpu_comm_group_started = True

def _xpu_comm_group_end():
    if not paddle.is_compiled_with_xpu():
        return
    global _xpu_comm_group_started
    if _xpu_comm_group_started:
        framework.core.ProcessGroupBKCL.group_end()
        _xpu_comm_group_started = False
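
# Usage pattern (descriptive sketch, not part of the committed file; the real
# callers are the send/recv sites inside _p2p_helper further down):
#
#     _xpu_comm_group_start()   # open a BKCL batch on XPU
#     ...post the send/recv ops for this pipeline step...
#     _xpu_comm_group_end()     # flush the batched ops
#
# Both helpers return immediately on non-XPU builds, so the calls are safe
# to leave in device-agnostic code paths.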

def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
    global _hcg, _use_cache, _enable_partial_send_recv
    _hcg = hcg
@@ -357,6 +376,7 @@ def _p2p_helper(
    # TODO(Yuang Liu): use batch_isend_irecv to replace all these comm ops
    tasks = []
    # start p2p communication
    if _sync_send:
        # Some devices (NPU, for example) do not support asynchronous send
        # ops, so the order is recv_prev -> send_next -> recv_next -> send_prev
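        # The intent of this fixed global order (editorial gloss): each
        # blocking send should find its matching recv already posted on the
        # peer rank, so pipeline stages do not end up waiting on one another.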
@@ -492,8 +512,8 @@ def _p2p_helper(
                    group=_hcg.send_prev_group,
                    use_calc_stream=False,
                )
    else:
        _xpu_comm_group_start()
        if tensor_send_prev is not None:
            if isinstance(tensor_send_prev, tuple):
                for d in tensor_send_prev:
@@ -529,6 +549,7 @@ def _p2p_helper(
                        use_calc_stream=sync_recv,
                    )
                    if sync_recv:
                        _xpu_comm_group_end()
                        allgather_partial(
                            d,
                            nranks=mp_degree,
@@ -549,6 +570,7 @@ def _p2p_helper(
                )
                if sync_recv:
                    _xpu_comm_group_end()
                    allgather_partial(
                        tensor_recv_prev,
                        nranks=mp_degree,
@@ -595,6 +617,7 @@ def _p2p_helper(
                    )
                    if sync_recv:
                        _xpu_comm_group_end()
                        allgather_partial(
                            d,
                            nranks=mp_degree,
@@ -615,6 +638,7 @@ def _p2p_helper(
                    use_calc_stream=sync_recv,
                )
                if sync_recv:
                    _xpu_comm_group_end()
                    allgather_partial(
                        tensor_recv_next,
                        nranks=mp_degree,
@@ -624,7 +648,7 @@ def _p2p_helper(
                    )
                else:
                    tasks.append(task)

        _xpu_comm_group_end()

    if not sync_recv:
        if framework.in_dygraph_mode():
            # wait on irecv tasks in eager dygraph mode with the new comm library
......
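Taken together, the hunks bracket the asynchronous send/recv path with _xpu_comm_group_start() / _xpu_comm_group_end() so BKCL can batch the point-to-point calls on XPU. Below is a minimal sketch of that control flow, not the committed code: post_send_ops and post_recv_ops are hypothetical stand-ins for the send_partial / recv_partial calls elided by the truncated hunks; only the group bracketing, the flush before allgather_partial, and the final task wait mirror the diff.

# Sketch only; relies on _xpu_comm_group_start/_xpu_comm_group_end from above.
def _p2p_async_sketch(post_send_ops, post_recv_ops, sync_recv=False):
    tasks = []
    _xpu_comm_group_start()  # open a BKCL batch (no-op off XPU)
    for post in post_send_ops:
        post()  # sends are fire-and-forget in this branch
    for post in post_recv_ops:
        task = post()  # recv_partial-style call returning a waitable task
        if sync_recv:
            # Flush the batch before reading the received buffer; in the real
            # code an allgather_partial then reassembles the mp-sharded tensor.
            _xpu_comm_group_end()
        else:
            tasks.append(task)
    _xpu_comm_group_end()  # no-op if the batch was already flushed
    if not sync_recv:
        for task in tasks:
            task.wait()  # wait on irecv tasks, as in the eager dygraph path
    return tasks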