diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index ab1cf9701dd63477bc8a0c1ecb35224dc9037dd3..2a7dd4d0bb71038fd64c346c22a91840af7863e5 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -75,6 +75,9 @@ class PipelineParallel(MetaParallelBase):
         ].dp_comm_overlap
         self._dp_comm_buffers = []
 
+        if self._dp_comm_overlap:
+            assert self.use_data_parallel and self.num_stages > 1
+
         p2p.initialize_p2p_groups(
             hcg, self._using_cache, self._enable_partial_send_recv
         )
diff --git a/python/paddle/distributed/fleet/optimizer.py b/python/paddle/distributed/fleet/optimizer.py
index 89b0456fe6abf0603e96d4a09603ef3e3049ea2b..5abe7c47e9b25abeb7e5af58d2f1c193e26e98f0 100755
--- a/python/paddle/distributed/fleet/optimizer.py
+++ b/python/paddle/distributed/fleet/optimizer.py
@@ -62,9 +62,16 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None):
 
     if fleet_env.worker_num() > 1:
         if not fleet_env._user_defined_strategy.heter_ccl_mode:
-            return HybridParallelOptimizer(
+            hp_optim = HybridParallelOptimizer(
                 optimizer, fleet_env._hcg, fleet_env._user_defined_strategy
             )
+
+            if fleet_env._user_defined_strategy.hybrid_configs[
+                "pp_configs"
+            ].dp_comm_overlap:
+                hp_optim._dp_enable = False
+
+            return hp_optim
         else:
             return HeterParallelOptimizer(
                 optimizer, fleet_env._user_defined_strategy
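
Note (not part of the patch): the two hunks work together. When pp_configs.dp_comm_overlap is enabled, PipelineParallel asserts that data parallelism is active and there is more than one pipeline stage, and the fleet optimizer wrapper sets _dp_enable = False on the returned HybridParallelOptimizer so the data-parallel all-reduce is handled by the fused gradient buffers in the pipeline pass rather than issued a second time at optimizer step. A minimal usage sketch follows; the parallel degrees and the nested "pp_configs" dict form of hybrid_configs are assumptions for illustration, and the model/optimizer setup is elided.

# Hedged sketch: enabling dp_comm_overlap through the user-facing strategy.
# Assumes a 2-way data-parallel x 2-stage pipeline-parallel launch via
# paddle.distributed.launch on 4 ranks.
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,
    "mp_degree": 1,
    "pp_degree": 2,
    "pp_configs": {"dp_comm_overlap": True},  # assumed nested-dict spelling
}
fleet.init(is_collective=True, strategy=strategy)

# Model/optimizer construction omitted. With dp_comm_overlap on, the object
# returned here is a HybridParallelOptimizer with _dp_enable == False, per
# the optimizer.py hunk above:
# optimizer = fleet.distributed_optimizer(optimizer)
# model = fleet.distributed_model(model)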