未验证 提交 07878a34 编写于 作者: 张春乔 提交者: GitHub

rm _init_npu_pipeline_comm (#53150)

上级 43b950f7
......@@ -737,89 +737,6 @@ class ShardingOptimizer(MetaOptimizerBase):
sync=False,
)
def _init_npu_pipeline_comm(self, startup_block):
assert (self.pp_degree % 2) == 0
max_ring_id = -1
my_pair = []
for pair in self.pipeline_pair:
pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key]
max_ring_id = max(max_ring_id, ring_id)
logger.info(f"pp pair:{pair}, ring_id: {ring_id}")
if self.pp_rank in pair:
my_pair.append(pair)
# for example: self.pp_rank=2, self.pp_degree=4
send_to_next_pair = (
self.pp_rank,
(self.pp_rank + 1) % self.pp_degree,
) # 2->3
recv_from_next_pair = (
(self.pp_rank + 1) % self.pp_degree,
self.pp_rank,
) # 3->2
recv_from_prev_pair = (
(self.pp_rank - 1 + self.pp_degree) % self.pp_degree,
self.pp_rank,
) # 1->2
send_to_prev_pair = (
self.pp_rank,
(self.pp_rank - 1 + self.pp_degree) % self.pp_degree,
) # 2->1
even = (self.pp_rank % 2) == 0
# 1. even send to next, odd recv from prev, 0->1, 2->3
pair = send_to_next_pair if even else recv_from_prev_pair
ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
self._init_pair_comm(pair, ring_id)
my_pair.remove(pair)
logger.info(f"pair0(even->odd): pp pair:{pair}, ring_id: {ring_id}")
# 2. even recv from next, odd send to prev, 1->0, 3->2
pair = recv_from_next_pair if even else send_to_prev_pair
ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
self._init_pair_comm(pair, ring_id)
my_pair.remove(pair)
logger.info(f"pair1(even<-odd): pp pair:{pair}, ring_id: {ring_id}")
# if pp_degree is 2, only need pair(0->1, 1->0)
if self.pp_degree > 2:
# 3. odd send to next, even recv from prev, 1->2, 3->0
pair = send_to_next_pair if not even else recv_from_prev_pair
ring_id = self.pp_ring_map.get(
pair[0] * 1000 + pair[1], max_ring_id + 1
) # 3->0 not in pp_ring_map
self._init_pair_comm(pair, ring_id)
if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
my_pair.remove(pair)
logger.info(
"pair2(odd->even): pp pair:{}, ring_id: {}".format(
pair, ring_id
)
)
# 4. odd recv from next, even send to prev, 2->1, 0->3
pair = recv_from_next_pair if not even else send_to_prev_pair
ring_id = self.pp_ring_map.get(
pair[0] * 1000 + pair[1], max_ring_id + 2
) # 0->3 not in pp_ring_map
self._init_pair_comm(pair, ring_id)
if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
my_pair.remove(pair)
logger.info(
"pair3(odd<-even): pp pair:{}, ring_id: {}".format(
pair, ring_id
)
)
assert len(my_pair) == 0, (
"Current pipeline does not support cross stage communication, "
"please check unexpected pair {}".format(my_pair)
)
def _init_pipeline_comm(self, startup_block):
# TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
......@@ -834,7 +751,6 @@ class ShardingOptimizer(MetaOptimizerBase):
)
if core.is_compiled_with_custom_device('npu'):
self._init_npu_pipeline_comm(startup_block)
return
# GPU
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册