From cb6510ffd2246f364050a501cf83bb42f5c845a5 Mon Sep 17 00:00:00 2001
From: WangXi
Date: Wed, 14 Jul 2021 14:25:40 +0800
Subject: [PATCH] [hybrid fix] fix pp+dp hang (#34142)

---
 .../meta_optimizers/sharding_optimizer.py  | 46 +++++++++----------
 .../test_fleet_sharding_meta_optimizer.py  |  5 +-
 2 files changed, 25 insertions(+), 26 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 0f103c0709a..a74f923dea4 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -434,35 +434,31 @@ class ShardingOptimizer(MetaOptimizerBase):
 
         # pp ring
         if self.pp_degree > 1:
+            # TODO (JZ-LIANG) to unify this shit
+            assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
+                self.pp_rank_, self.pp_rank)
+
             for pair in self.pipeline_pair:
                 pair_key = pair[0] * 1000 + pair[1]
                 ring_id = self.pp_ring_map[pair_key]
                 print("pp pair:{}, ring_id: {}".format(pair, ring_id))
-                if self.pp_rank not in pair: continue
-                pp_group_endpoints = [
-                    self.pp_group_endpoints[pair[0]],
-                    self.pp_group_endpoints[pair[1]],
-                ]
-                if pair[0] < pair[1]:
-                    start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
-                else:
-                    start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
-                pp_rank = 0 if self.pp_rank == pair[0] else 1
-                self._collective_helper._init_communicator(
-                    self._startup_program,
-                    self.current_endpoint,
-                    pp_group_endpoints,
-                    pp_rank,
-                    ring_id,
-                    False,
-                    global_ring_id=self.global_ring_id,
-                    sync=False)
-                # append_naive_sync(startup_block, self.startup_prog_sync_var,
-                #                   self.global_ring_id)
-
-            # TODO (JZ-LIANG) to unify this shit
-            assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
-                self.pp_rank_, self.pp_rank)
+                if self.pp_rank in pair:
+                    pp_group_endpoints = [
+                        self.pp_group_endpoints[pair[0]],
+                        self.pp_group_endpoints[pair[1]],
+                    ]
+                    pp_rank = 0 if self.pp_rank == pair[0] else 1
+                    self._collective_helper._init_communicator(
+                        self._startup_program,
+                        self.current_endpoint,
+                        pp_group_endpoints,
+                        pp_rank,
+                        ring_id,
+                        False,
+                        global_ring_id=self.global_ring_id,
+                        sync=False)
+                append_naive_sync(startup_block, self.startup_prog_sync_var,
+                                  self.global_ring_id)
 
         # pure dp ring
         if self.dp_degree > 1:
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
index af020548af3..a29d752ed75 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -525,6 +525,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
         startup_prog_op_types = [op.type for op in startup_prog_ops]
         main_prog_op_types = [op.type for op in main_prog_ops]
         print(startup_prog_op_types)
+        # global, sharding, pp_send, pp_recv
         self.assertEqual(startup_prog_op_types, [
             'fill_constant', 'uniform_random', 'fill_constant',
             'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
             'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
             'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
             'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream',
-            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init'
+            'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
+            'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
+            'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream'
         ])
 
         self.assertEqual(main_prog_op_types, [
-- 
GitLab
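
A minimal sketch, not part of the patch above, of the synchronization pattern the fix enforces: every pipeline rank walks the pair list in the same order and joins a global sync after each pair-wise communicator init, whether or not it belongs to that pair, instead of skipping ahead via `continue` while append_naive_sync stayed commented out. The names below are hypothetical stand-ins: threading.Barrier plays the role of append_naive_sync over the global ring, and init_pair_communicator() stands in for _collective_helper._init_communicator().

# Toy simulation of the fixed startup flow; names and sizes are illustrative.
import threading

WORLD_SIZE = 4
PIPELINE_PAIRS = [(0, 1), (1, 2), (2, 3)]      # pp send/recv pairs
global_sync = threading.Barrier(WORLD_SIZE)    # stand-in for append_naive_sync

def init_pair_communicator(rank, pair, ring_id):
    # In Paddle this step would emit c_gen_nccl_id / c_comm_init ops.
    print("rank {}: init ring {} for pair {}".format(rank, ring_id, pair))

def init_pipeline_rings(rank):
    for src, dst in PIPELINE_PAIRS:
        ring_id = src * 1000 + dst
        if rank in (src, dst):
            # Only the two endpoints of the pair create this communicator.
            init_pair_communicator(rank, (src, dst), ring_id)
        # Every rank, pair member or not, waits here before the next pair,
        # so no rank races ahead into later ring/dp setup and hangs.
        global_sync.wait()

threads = [threading.Thread(target=init_pipeline_rings, args=(r,))
           for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()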