未验证 提交 4a97ba5d 编写于 作者: R Roc 提交者: GitHub

[KUNLUN]Revert "revert p2p communication for xpu (#53496)" (#53633)

* Revert "revert p2p communication for xpu (#53496)"

This reverts commit eda0c588.

* update
上级 80757527
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import numpy as np import numpy as np
...@@ -27,10 +26,30 @@ _hcg = None ...@@ -27,10 +26,30 @@ _hcg = None
_use_cache = False _use_cache = False
_enable_partial_send_recv = True _enable_partial_send_recv = True
_xpu_comm_group_started = False
_sync_send = os.environ.get("PADDLE_P2P_SYNC_SEND", "0") _sync_send = os.environ.get("PADDLE_P2P_SYNC_SEND", "0")
_sync_send = _sync_send.lower() in ['1', 'true'] _sync_send = _sync_send.lower() in ['1', 'true']
def _xpu_comm_group_start():
if not paddle.is_compiled_with_xpu():
return
global _xpu_comm_group_started
assert not _xpu_comm_group_started
framework.core.ProcessGroupBKCL.group_start()
_xpu_comm_group_started = True
def _xpu_comm_group_end():
if not paddle.is_compiled_with_xpu():
return
global _xpu_comm_group_started
if _xpu_comm_group_started:
framework.core.ProcessGroupBKCL.group_end()
_xpu_comm_group_started = False
def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True): def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
global _hcg, _use_cache, _enable_partial_send_recv global _hcg, _use_cache, _enable_partial_send_recv
_hcg = hcg _hcg = hcg
...@@ -357,6 +376,7 @@ def _p2p_helper( ...@@ -357,6 +376,7 @@ def _p2p_helper(
# TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops
tasks = [] tasks = []
# start to p2p communicate # start to p2p communicate
if _sync_send: if _sync_send:
# Some devices(NPU for example) do not support asynchronized send op, So the order is # Some devices(NPU for example) do not support asynchronized send op, So the order is
# recv_prev -> send_next -> recv_next -> send_prev # recv_prev -> send_next -> recv_next -> send_prev
...@@ -492,8 +512,8 @@ def _p2p_helper( ...@@ -492,8 +512,8 @@ def _p2p_helper(
group=_hcg.send_prev_group, group=_hcg.send_prev_group,
use_calc_stream=False, use_calc_stream=False,
) )
else: else:
_xpu_comm_group_start()
if tensor_send_prev is not None: if tensor_send_prev is not None:
if isinstance(tensor_send_prev, tuple): if isinstance(tensor_send_prev, tuple):
for d in tensor_send_prev: for d in tensor_send_prev:
...@@ -529,6 +549,7 @@ def _p2p_helper( ...@@ -529,6 +549,7 @@ def _p2p_helper(
use_calc_stream=sync_recv, use_calc_stream=sync_recv,
) )
if sync_recv: if sync_recv:
_xpu_comm_group_end()
allgather_partial( allgather_partial(
d, d,
nranks=mp_degree, nranks=mp_degree,
...@@ -549,6 +570,7 @@ def _p2p_helper( ...@@ -549,6 +570,7 @@ def _p2p_helper(
) )
if sync_recv: if sync_recv:
_xpu_comm_group_end()
allgather_partial( allgather_partial(
tensor_recv_prev, tensor_recv_prev,
nranks=mp_degree, nranks=mp_degree,
...@@ -595,6 +617,7 @@ def _p2p_helper( ...@@ -595,6 +617,7 @@ def _p2p_helper(
) )
if sync_recv: if sync_recv:
_xpu_comm_group_end()
allgather_partial( allgather_partial(
d, d,
nranks=mp_degree, nranks=mp_degree,
...@@ -615,6 +638,7 @@ def _p2p_helper( ...@@ -615,6 +638,7 @@ def _p2p_helper(
use_calc_stream=sync_recv, use_calc_stream=sync_recv,
) )
if sync_recv: if sync_recv:
_xpu_comm_group_end()
allgather_partial( allgather_partial(
tensor_recv_next, tensor_recv_next,
nranks=mp_degree, nranks=mp_degree,
...@@ -624,7 +648,7 @@ def _p2p_helper( ...@@ -624,7 +648,7 @@ def _p2p_helper(
) )
else: else:
tasks.append(task) tasks.append(task)
_xpu_comm_group_end()
if not sync_recv: if not sync_recv:
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
# wait irecv tasks in eager dygraph mode with new comm library # wait irecv tasks in eager dygraph mode with new comm library
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册