未验证 提交 cb6510ff 作者: WangXi 提交者: GitHub

[hybrid fix] fix pp+dp hang (#34142)

上级 7f26453f
...@@ -434,35 +434,31 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -434,35 +434,31 @@ class ShardingOptimizer(MetaOptimizerBase):
# pp ring # pp ring
if self.pp_degree > 1: if self.pp_degree > 1:
# TODO (JZ-LIANG) to unify this shit
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank)
for pair in self.pipeline_pair: for pair in self.pipeline_pair:
pair_key = pair[0] * 1000 + pair[1] pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key] ring_id = self.pp_ring_map[pair_key]
print("pp pair:{}, ring_id: {}".format(pair, ring_id)) print("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank not in pair: continue if self.pp_rank in pair:
pp_group_endpoints = [ pp_group_endpoints = [
self.pp_group_endpoints[pair[0]], self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]], self.pp_group_endpoints[pair[1]],
] ]
if pair[0] < pair[1]: pp_rank = 0 if self.pp_rank == pair[0] else 1
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1 self._collective_helper._init_communicator(
else: self._startup_program,
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1 self.current_endpoint,
pp_rank = 0 if self.pp_rank == pair[0] else 1 pp_group_endpoints,
self._collective_helper._init_communicator( pp_rank,
self._startup_program, ring_id,
self.current_endpoint, False,
pp_group_endpoints, global_ring_id=self.global_ring_id,
pp_rank, sync=False)
ring_id, append_naive_sync(startup_block, self.startup_prog_sync_var,
False, self.global_ring_id)
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
# TODO (JZ-LIANG) to unify this shit
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank)
# pure dp ring # pure dp ring
if self.dp_degree > 1: if self.dp_degree > 1:
......
...@@ -525,6 +525,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer): ...@@ -525,6 +525,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
startup_prog_op_types = [op.type for op in startup_prog_ops] startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops] main_prog_op_types = [op.type for op in main_prog_ops]
print(startup_prog_op_types) print(startup_prog_op_types)
# global, sharding, pp_send, pp_recv
self.assertEqual(startup_prog_op_types, [ self.assertEqual(startup_prog_op_types, [
'fill_constant', 'uniform_random', 'fill_constant', 'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
...@@ -532,7 +533,9 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer): ...@@ -532,7 +533,9 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum', 'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream', 'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init' 'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream'
]) ])
self.assertEqual(main_prog_op_types, [ self.assertEqual(main_prog_op_types, [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册