未验证 提交 ff806111 编写于 作者: L LiYuRio 提交者: GitHub

separate four directions p2p communication to a new file (#54664)

上级 aac91e82
......@@ -13,6 +13,7 @@
# limitations under the License.
import collections
import os
from functools import reduce
from itertools import product
......@@ -24,6 +25,9 @@ from ..utils.log_util import logger
__all__ = ['CommunicateTopology', 'HybridCommunicateGroup']
_HYBRID_PARALLEL_GROUP = None
# Whether to use the "four directions" p2p scheme (separate send/recv
# prev/next groups). Defaults to on for XPU builds; can be overridden via
# the PADDLE_USE_FOUR_DIRECTIONS_P2P environment variable.
# NOTE: os.environ.get returns a *string* when the variable is set, and any
# non-empty string (including "0" and "False") is truthy -- parse the value
# explicitly so the flag can actually be disabled from the environment.
_use_four_directions_env = os.environ.get('PADDLE_USE_FOUR_DIRECTIONS_P2P')
if _use_four_directions_env is None:
    _use_four_directions = paddle.fluid.core.is_compiled_with_xpu()
else:
    _use_four_directions = _use_four_directions_env.strip().lower() in (
        '1',
        'true',
        'on',
    )
class ParallelMode:
......@@ -191,7 +195,9 @@ class HybridCommunicateGroup:
if self._pp_degree > 1:
if paddle.framework.core.is_compiled_with_nccl():
check_nccl_version_for_p2p()
self._set_p2p_group()
self._set_p2p_prev_next()
if _use_four_directions:
self._set_four_directions_p2p_group()
debug_str = (
"HybridParallelInfo: rank_id: %d, mp_degree: %d, "
......@@ -291,7 +297,7 @@ class HybridCommunicateGroup:
assert hasattr(self, 'prev_rank'), "prev_rank has not been inited"
return self.prev_rank
def _set_p2p_group(self):
def _set_p2p_prev_next(self):
comm_lists = self._topo.get_comm_list('pipe')
for comm_ranks in comm_lists:
......@@ -305,6 +311,43 @@ class HybridCommunicateGroup:
self.next_rank = next_rank
self.prev_rank = prev_rank
def _set_four_directions_p2p_group(self):
    """Create the four directional p2p groups for pipeline parallelism.

    For this rank's position in its pipeline, builds two-rank process
    groups toward the next and previous pipeline stages and stores them as
    ``send_next_group``, ``send_prev_group``, ``recv_next_group`` and
    ``recv_prev_group``.
    """
    comm_lists = self._topo.get_comm_list('pipe')

    self.send_next_group = None
    self.send_prev_group = None
    self.recv_next_group = None
    self.recv_prev_group = None
    # Every rank walks every pipeline and creates every pair group, so the
    # sequence of new_group calls is identical on all ranks.
    # NOTE(review): this looks required because new_group is a collective
    # operation -- confirm against paddle.distributed semantics.
    for comm_ranks in comm_lists:
        assert len(comm_ranks) == self._pp_degree
        for idx, rank in enumerate(comm_ranks):
            curr_rank = rank
            # Modulo wraps around, so the last stage pairs with the first.
            next_rank = comm_ranks[(idx + 1) % self._pp_degree]
            prev_rank = comm_ranks[(idx - 1) % self._pp_degree]

            # Group [curr, next]: curr sends forward, next receives from
            # its previous stage -- the same group serves both directions.
            next_group = paddle.distributed.new_group(
                ranks=[curr_rank, next_rank]
            )
            if self.global_rank == curr_rank:
                self.send_next_group = next_group
            elif self.global_rank == next_rank:
                self.recv_prev_group = next_group

            # Group [prev, curr]: curr sends backward, prev receives from
            # its next stage.
            prev_group = paddle.distributed.new_group(
                ranks=[prev_rank, curr_rank]
            )

            if self.global_rank == curr_rank:
                self.send_prev_group = prev_group
            elif self.global_rank == prev_rank:
                self.recv_next_group = prev_group

    # Every rank belongs to exactly one pipeline, so all four groups must
    # have been assigned by the loops above.
    assert self.send_next_group is not None
    assert self.send_prev_group is not None
    assert self.recv_next_group is not None
    assert self.recv_prev_group is not None
def topology(self):
    """Return the communication topology object backing this group."""
    topo = self._topo
    return topo
......@@ -357,7 +400,15 @@ class HybridCommunicateGroup:
return self._pp_comm_group
def get_p2p_groups(self):
    """Return the four directional p2p groups for pipeline communication.

    Returns:
        tuple: ``(send_next_group, send_prev_group, recv_next_group,
        recv_prev_group)`` as created by ``_set_four_directions_p2p_group``.

    Raises:
        AssertionError: if the four-directions p2p scheme is not enabled
            (``PADDLE_USE_FOUR_DIRECTIONS_P2P``).
    """
    # The stale `return None` left over from the previous implementation
    # made everything below unreachable; it has been removed.
    assert (
        _use_four_directions
    ), "If you want to use four directions p2p group, set the environment variable PADDLE_USE_FOUR_DIRECTIONS_P2P to True."
    return (
        self.send_next_group,
        self.send_prev_group,
        self.recv_next_group,
        self.recv_prev_group,
    )
# sharding parallel message:
def _get_sharding_parallel_id(self):
......
......@@ -13,6 +13,8 @@
import time
import warnings
import os
import paddle
from paddle import framework
......@@ -26,7 +28,15 @@ from ..utils.hybrid_parallel_util import (
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
# Select the p2p communication implementation once at import time: the
# four-directions variant for XPU (or when forced via the environment),
# otherwise the default implementation. Both expose the same `p2p` API.
# NOTE: os.environ.get returns a *string* when the variable is set, and any
# non-empty string (including "0" and "False") is truthy -- parse the value
# explicitly so the flag can actually be disabled from the environment.
_use_four_directions_env = os.environ.get('PADDLE_USE_FOUR_DIRECTIONS_P2P')
if _use_four_directions_env is None:
    _use_four_directions = paddle.fluid.core.is_compiled_with_xpu()
else:
    _use_four_directions = _use_four_directions_env.strip().lower() in (
        '1',
        'true',
        'on',
    )

if _use_four_directions:
    from .pp_utils import four_directions_p2p_communication as p2p
else:
    from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size
__all__ = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册