Unverified commit 41e2d413, authored by WangXi, committed by GitHub

[NPU] fix npu pipeline comm init (#34466)

Parent 8b72a1a7
@@ -379,6 +379,119 @@ class ShardingOptimizer(MetaOptimizerBase):
        self._wait()
        return optimize_ops, params_grads

    def _init_pair_comm(self, pair, ring_id):
        pp_group_endpoints = [
            self.pp_group_endpoints[pair[0]],
            self.pp_group_endpoints[pair[1]],
        ]
        pp_rank = 0 if self.pp_rank == pair[0] else 1
        self._collective_helper._init_communicator(
            self._startup_program,
            self.current_endpoint,
            pp_group_endpoints,
            pp_rank,
            ring_id,
            False,
            global_ring_id=self.global_ring_id,
            sync=False)

    def _init_npu_pipeline_comm(self, startup_block):
        # NOTE(wangxi): some bug with hccl, must set pp_degree be even number
        assert (self.pp_degree % 2) == 0

        max_ring_id = -1
        my_pair = []
        for pair in self.pipeline_pair:
            pair_key = pair[0] * 1000 + pair[1]
            ring_id = self.pp_ring_map[pair_key]
            max_ring_id = max(max_ring_id, ring_id)
            logger.info("pp pair:{}, ring_id: {}".format(pair, ring_id))

            if self.pp_rank in pair:
                my_pair.append(pair)

        # for example: self.pp_rank=2, self.pp_degree=4
        send_to_next_pair = (self.pp_rank,
                             (self.pp_rank + 1) % self.pp_degree)  # 2->3
        recv_from_next_pair = ((self.pp_rank + 1) % self.pp_degree,
                               self.pp_rank)  # 3->2
        recv_from_prev_pair = ((self.pp_rank - 1 + self.pp_degree) %
                               self.pp_degree, self.pp_rank)  # 1->2
        send_to_prev_pair = (self.pp_rank, (self.pp_rank - 1 + self.pp_degree) %
                             self.pp_degree)  # 2->1

        even = (self.pp_rank % 2) == 0

        # 1. even send to next, odd recv from prev, 0->1, 2->3
        pair = send_to_next_pair if even else recv_from_prev_pair
        ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
        self._init_pair_comm(pair, ring_id)
        append_naive_sync(startup_block, self.startup_prog_sync_var,
                          self.global_ring_id)
        my_pair.remove(pair)
        logger.info("pair0(even->odd): pp pair:{}, ring_id: {}".format(pair,
                                                                       ring_id))

        # 2. even recv from next, odd send to prev, 1->0, 3->2
        pair = recv_from_next_pair if even else send_to_prev_pair
        ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
        self._init_pair_comm(pair, ring_id)
        append_naive_sync(startup_block, self.startup_prog_sync_var,
                          self.global_ring_id)
        my_pair.remove(pair)
        logger.info("pair1(even<-odd): pp pair:{}, ring_id: {}".format(pair,
                                                                       ring_id))

        # if pp_degree is 2, only need pair(0->1, 1->0)
        if self.pp_degree > 2:
            # 3. odd send to next, even recv from prev, 1->2, 3->0
            pair = send_to_next_pair if not even else recv_from_prev_pair
            ring_id = self.pp_ring_map.get(
                pair[0] * 1000 + pair[1],
                max_ring_id + 1)  # 3->0 not in pp_ring_map
            self._init_pair_comm(pair, ring_id)
            append_naive_sync(startup_block, self.startup_prog_sync_var,
                              self.global_ring_id)
            if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
                my_pair.remove(pair)
            logger.info("pair2(odd->even): pp pair:{}, ring_id: {}".format(
                pair, ring_id))

            # 4. odd recv from next, even send to prev, 2->1, 0->3
            pair = recv_from_next_pair if not even else send_to_prev_pair
            ring_id = self.pp_ring_map.get(
                pair[0] * 1000 + pair[1],
                max_ring_id + 2)  # 0->3 not in pp_ring_map
            self._init_pair_comm(pair, ring_id)
            append_naive_sync(startup_block, self.startup_prog_sync_var,
                              self.global_ring_id)
            if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
                my_pair.remove(pair)
            logger.info("pair3(odd<-even): pp pair:{}, ring_id: {}".format(
                pair, ring_id))

        assert len(my_pair) == 0, "Current pipeline does not support cross stage communication, " \
            "please check unexpected pair {}".format(my_pair)

    def _init_pipeline_comm(self, startup_block):
        # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
        assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
            self.pp_rank_, self.pp_rank)

        if core.is_compiled_with_npu():
            self._init_npu_pipeline_comm(startup_block)
            return

        # GPU
        for pair in self.pipeline_pair:
            pair_key = pair[0] * 1000 + pair[1]
            ring_id = self.pp_ring_map[pair_key]
            logger.info("pp pair:{}, ring_id: {}".format(pair, ring_id))
            if self.pp_rank in pair:
                self._init_pair_comm(pair, ring_id)
                append_naive_sync(startup_block, self.startup_prog_sync_var,
                                  self.global_ring_id)

    def _init_comm(self):
        # config sharding & dp groups
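The four numbered phases above order communicator creation so that at any step each rank is setting up at most one pairwise communicator, and both endpoints of that pair do it at the same step (the NOTE hints at an HCCL issue that makes this ordering necessary). A minimal standalone sketch, plain Python with no Paddle dependency and pp_degree=4 as an assumed example, reproduces the pairing order and shows that every rank meets exactly one peer per phase:

    # Sketch only: mirrors the pair ordering in _init_npu_pipeline_comm
    # for an assumed pp_degree of 4. Phases 1/2 cover the even->odd links
    # (0<->1, 2<->3); phases 3/4 cover the odd->even links (1<->2, 3<->0).
    def npu_pair_schedule(pp_rank, pp_degree=4):
        nxt = (pp_rank + 1) % pp_degree
        prv = (pp_rank - 1 + pp_degree) % pp_degree
        even = (pp_rank % 2) == 0
        phases = [
            (pp_rank, nxt) if even else (prv, pp_rank),  # 1. even sends to next
            (nxt, pp_rank) if even else (pp_rank, prv),  # 2. even recvs from next
        ]
        if pp_degree > 2:
            phases += [
                (prv, pp_rank) if even else (pp_rank, nxt),  # 3. odd sends to next
                (pp_rank, prv) if even else (nxt, pp_rank),  # 4. odd recvs from next
            ]
        return phases

    for rank in range(4):
        print(rank, npu_pair_schedule(rank))
    # 0 [(0, 1), (1, 0), (3, 0), (0, 3)]
    # 1 [(0, 1), (1, 0), (1, 2), (2, 1)]
    # 2 [(2, 3), (3, 2), (1, 2), (2, 1)]
    # 3 [(2, 3), (3, 2), (3, 0), (0, 3)]

In every phase the two members of a pair compute the same tuple, so they agree on the same ring_id and finish communicator setup together before the global naive sync. With an odd pp_degree, ranks pp_degree-1 and 0 would share the same parity and their wrap-around link could never be matched within a single phase, which is presumably why the assert requires pp_degree to be even.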
@@ -435,31 +548,7 @@ class ShardingOptimizer(MetaOptimizerBase):
        # pp ring
        if self.pp_degree > 1:
            # TODO (JZ-LIANG) to unify this shit
            assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
                self.pp_rank_, self.pp_rank)
            for pair in self.pipeline_pair:
                pair_key = pair[0] * 1000 + pair[1]
                ring_id = self.pp_ring_map[pair_key]
                print("pp pair:{}, ring_id: {}".format(pair, ring_id))
                if self.pp_rank in pair:
                    pp_group_endpoints = [
                        self.pp_group_endpoints[pair[0]],
                        self.pp_group_endpoints[pair[1]],
                    ]
                    pp_rank = 0 if self.pp_rank == pair[0] else 1
                    self._collective_helper._init_communicator(
                        self._startup_program,
                        self.current_endpoint,
                        pp_group_endpoints,
                        pp_rank,
                        ring_id,
                        False,
                        global_ring_id=self.global_ring_id,
                        sync=False)
                    append_naive_sync(startup_block, self.startup_prog_sync_var,
                                      self.global_ring_id)
            self._init_pipeline_comm(startup_block)

        # pure dp ring
        if self.dp_degree > 1:
......
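Both the NPU and GPU paths key the ring map by `pair[0] * 1000 + pair[1]`. A small illustration of the lookup, and of the fallback ids used for the wrap-around pairs that are absent from the map, with hypothetical ring id values (the real pp_ring_map is built by the pipeline pass elsewhere in the optimizer):

    # Hypothetical map for pp_degree = 4; the numeric ring ids are
    # illustrative only, not what the pipeline pass actually assigns.
    pp_ring_map = {
        0 * 1000 + 1: 20,  # stage 0 -> stage 1
        1 * 1000 + 0: 21,
        1 * 1000 + 2: 22,
        2 * 1000 + 1: 23,
        2 * 1000 + 3: 24,
        3 * 1000 + 2: 25,
    }
    max_ring_id = max(pp_ring_map.values())  # 25

    # The wrap-around pairs (3, 0) and (0, 3) are not in the map, so
    # _init_npu_pipeline_comm falls back to fresh ring ids for them:
    ring_3_to_0 = pp_ring_map.get(3 * 1000 + 0, max_ring_id + 1)  # 26
    ring_0_to_3 = pp_ring_map.get(0 * 1000 + 3, max_ring_id + 2)  # 27

The multiply-by-1000 encoding simply packs an ordered (src, dst) stage pair into a single integer key, which works as long as pp_degree stays below 1000.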