未验证 提交 cb6510ff 编写于 作者: W WangXi 提交者: GitHub

[hybrid fix] fix pp+dp hang (#34142)

上级 7f26453f
......@@ -434,19 +434,19 @@ class ShardingOptimizer(MetaOptimizerBase):
# pp ring
if self.pp_degree > 1:
# TODO (JZ-LIANG) to unify this shit
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank)
for pair in self.pipeline_pair:
pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key]
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank not in pair: continue
if self.pp_rank in pair:
pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]],
]
if pair[0] < pair[1]:
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program,
......@@ -457,12 +457,8 @@ class ShardingOptimizer(MetaOptimizerBase):
False,
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
# TODO (JZ-LIANG) to unify this shit
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# pure dp ring
if self.dp_degree > 1:
......
......@@ -525,6 +525,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]
print(startup_prog_op_types)
# global, sharding, pp_send, pp_recv
self.assertEqual(startup_prog_op_types, [
'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
......@@ -532,7 +533,9 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init'
'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream'
])
self.assertEqual(main_prog_op_types, [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册