Unverified commit ff806111, authored by LiYuRio, committed by GitHub

separate four directions p2p communication to a new file (#54664)

Parent: aac91e82
@@ -13,6 +13,7 @@
 # limitations under the License.
 import collections
+import os
 from functools import reduce
 from itertools import product

@@ -24,6 +25,9 @@ from ..utils.log_util import logger
 __all__ = ['CommunicateTopology', 'HybridCommunicateGroup']

 _HYBRID_PARALLEL_GROUP = None
+_use_four_directions = os.environ.get(
+    'PADDLE_USE_FOUR_DIRECTIONS_P2P', paddle.fluid.core.is_compiled_with_xpu()
+)


 class ParallelMode:
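One wrinkle worth noting about the new flag: `os.environ.get` returns the default (here the boolean XPU build flag) only when the variable is unset; when it is set, the raw string comes back, so any non-empty value reads as enabled under a plain `if` test. A minimal sketch, assuming only the standard library; the helper name below is ours, for illustration only:

```python
import os

def use_four_directions(compiled_with_xpu: bool) -> bool:
    # Hypothetical helper mirroring the flag resolution in the hunk above.
    # Unset: falls back to the build flag (a real bool).
    # Set: os.environ.get returns the raw string, and any non-empty string,
    # including "0" or "False", is truthy.
    return bool(os.environ.get('PADDLE_USE_FOUR_DIRECTIONS_P2P', compiled_with_xpu))

os.environ.pop('PADDLE_USE_FOUR_DIRECTIONS_P2P', None)
assert use_four_directions(False) is False   # env unset, non-XPU build
os.environ['PADDLE_USE_FOUR_DIRECTIONS_P2P'] = '1'
assert use_four_directions(False) is True    # env set, overrides the build flag
```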
@@ -191,7 +195,9 @@ class HybridCommunicateGroup:
         if self._pp_degree > 1:
             if paddle.framework.core.is_compiled_with_nccl():
                 check_nccl_version_for_p2p()
-            self._set_p2p_group()
+            self._set_p2p_prev_next()
+            if _use_four_directions:
+                self._set_four_directions_p2p_group()

         debug_str = (
             "HybridParallelInfo: rank_id: %d, mp_degree: %d, "

@@ -291,7 +297,7 @@ class HybridCommunicateGroup:
         assert hasattr(self, 'prev_rank'), "prev_rank has not been inited"
         return self.prev_rank

-    def _set_p2p_group(self):
+    def _set_p2p_prev_next(self):
         comm_lists = self._topo.get_comm_list('pipe')

         for comm_ranks in comm_lists:

@@ -305,6 +311,43 @@ class HybridCommunicateGroup:
         self.next_rank = next_rank
         self.prev_rank = prev_rank

+    def _set_four_directions_p2p_group(self):
+        comm_lists = self._topo.get_comm_list('pipe')
+
+        self.send_next_group = None
+        self.send_prev_group = None
+        self.recv_next_group = None
+        self.recv_prev_group = None
+
+        for comm_ranks in comm_lists:
+            assert len(comm_ranks) == self._pp_degree
+            for idx, rank in enumerate(comm_ranks):
+                curr_rank = rank
+                next_rank = comm_ranks[(idx + 1) % self._pp_degree]
+                prev_rank = comm_ranks[(idx - 1) % self._pp_degree]
+
+                next_group = paddle.distributed.new_group(
+                    ranks=[curr_rank, next_rank]
+                )
+                if self.global_rank == curr_rank:
+                    self.send_next_group = next_group
+                elif self.global_rank == next_rank:
+                    self.recv_prev_group = next_group
+
+                prev_group = paddle.distributed.new_group(
+                    ranks=[prev_rank, curr_rank]
+                )
+                if self.global_rank == curr_rank:
+                    self.send_prev_group = prev_group
+                elif self.global_rank == prev_rank:
+                    self.recv_next_group = prev_group
+
+        assert self.send_next_group is not None
+        assert self.send_prev_group is not None
+        assert self.recv_next_group is not None
+        assert self.recv_prev_group is not None
+
     def topology(self):
         return self._topo
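To make the pairing concrete: for every adjacent pair of pipeline stages, two separate groups are created (one per direction), and each rank ends up holding four handles. A framework-free sketch of the same index arithmetic, with plain tuples standing in for `paddle.distributed` groups:

```python
def four_direction_pairs(comm_ranks, my_rank):
    """Replicates the pairing loop above, with tuples instead of real groups."""
    pp = len(comm_ranks)
    groups = {}
    for idx, curr in enumerate(comm_ranks):
        nxt = comm_ranks[(idx + 1) % pp]
        prv = comm_ranks[(idx - 1) % pp]
        if my_rank == curr:
            groups['send_next'] = (curr, nxt)  # next_group, as sender
            groups['send_prev'] = (prv, curr)  # prev_group, as sender
        if my_rank == nxt:
            groups['recv_prev'] = (curr, nxt)  # next_group, as receiver
        if my_rank == prv:
            groups['recv_next'] = (prv, curr)  # prev_group, as receiver
    return groups

# For a 4-stage pipeline, rank 1 talks to ranks 0 and 2 through four handles:
# {'recv_prev': (0, 1), 'send_next': (1, 2), 'send_prev': (0, 1), 'recv_next': (1, 2)}
print(four_direction_pairs([0, 1, 2, 3], 1))
```

Note that each adjacent pair gets two distinct communicators (the pair (0, 1) appears once as a `next_group` and once as a `prev_group`), presumably so that sends and receives in opposite directions do not serialize on a single group.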
@@ -357,7 +400,15 @@ class HybridCommunicateGroup:
         return self._pp_comm_group

     def get_p2p_groups(self):
-        return None
+        assert (
+            _use_four_directions
+        ), "If you want to use four directions p2p group, set the environment variable PADDLE_USE_FOUR_DIRECTIONS_P2P to True."
+        return (
+            self.send_next_group,
+            self.send_prev_group,
+            self.recv_next_group,
+            self.recv_prev_group,
+        )

     # sharding parallel message:
     def _get_sharding_parallel_id(self):
...
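With the flag enabled, callers can retrieve all four handles at once. A hypothetical usage sketch, to be run under `python -m paddle.distributed.launch`; the strategy values are placeholders and `fleet.get_hybrid_communicate_group()` is assumed to be the usual accessor after `fleet.init`:

```python
import os
os.environ['PADDLE_USE_FOUR_DIRECTIONS_P2P'] = '1'  # must be set before fleet.init

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 4}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
(
    send_next_group,
    send_prev_group,
    recv_next_group,
    recv_prev_group,
) = hcg.get_p2p_groups()  # raises an AssertionError unless the env flag is on
```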
@@ -13,6 +13,8 @@
 import time
 import warnings
+import os
+
 import paddle
 from paddle import framework

@@ -26,7 +28,15 @@ from ..utils.hybrid_parallel_util import (
 from ..utils.log_util import logger
 from .meta_parallel_base import MetaParallelBase
 from .parallel_layers.pp_layers import PipelineLayer
-from .pp_utils import p2p_communication as p2p
+
+_use_four_directions = os.environ.get(
+    'PADDLE_USE_FOUR_DIRECTIONS_P2P', paddle.fluid.core.is_compiled_with_xpu()
+)
+if _use_four_directions:
+    from .pp_utils import four_directions_p2p_communication as p2p
+else:
+    from .pp_utils import p2p_communication as p2p
+
 from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

 __all__ = []
...
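The module-level switch above selects the p2p backend once, at import time, and binds either module to the same name `p2p`, so the rest of the file stays agnostic to which backend is active. The same pattern in generic form (package and module names here are placeholders, not Paddle APIs):

```python
import importlib
import os

# Choose the backend module once, when this module is first imported.
# Simplification: the real code defaults to the XPU build flag when unset.
_backend = (
    'four_directions_p2p_communication'
    if os.environ.get('PADDLE_USE_FOUR_DIRECTIONS_P2P')
    else 'p2p_communication'
)
p2p = importlib.import_module(f'mypkg.pp_utils.{_backend}')
```

The trade-off of this design is that the choice is frozen at first import, which is why the environment variable has to be set before the pipeline module is loaded.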