diff --git a/python/paddle/distributed/fleet/meta_optimizers/common.py b/python/paddle/distributed/fleet/meta_optimizers/common.py
index 2b9899c6982317f678a04bcd56ac435d5da65948..687de7adea8389443d0bec9d94d189af852eb9db 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/common.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/common.py
@@ -66,14 +66,20 @@ class CollectiveHelper(object):
                 self.role_maker._worker_index(), ring_id, self.wait_port)
         self._broadcast_params()
 
-    def _init_communicator(self, program, current_endpoint, endpoints, rank,
-                           ring_id, wait_port):
+    def _init_communicator(self,
+                           program,
+                           current_endpoint,
+                           endpoints,
+                           rank,
+                           ring_id,
+                           wait_port,
+                           sync=True):
         nranks = len(endpoints)
         other_endpoints = endpoints[:]
         other_endpoints.remove(current_endpoint)
         block = program.global_block()
         if core.is_compiled_with_cuda():
-            if not wait_port:
+            if not wait_port and sync:
                 temp_var = block.create_var(
                     name=unique_name.generate('temp_var'),
                     dtype=core.VarDesc.VarType.INT32,
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 38dcc14427eb09b5cb0cc3e11961403ef4f0281b..f8e26ee2406f18152dd7c771c2f75ed05315824a 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -96,6 +96,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             "use_pipeline"]
         self.acc_steps = self.user_defined_strategy.sharding_configs[
             "acc_steps"]
+        self.schedule_mode = self.user_defined_strategy.sharding_configs[
+            "schedule_mode"]
 
         if self.inner_opt is None:
             raise ValueError(
@@ -105,6 +107,7 @@
                 self.acc_steps)
             main_program = loss.block.program
             main_program._pipeline_opt = dict()
+            main_program._pipeline_opt['schedule_mode'] = self.schedule_mode
             pp_rank = self.role_maker._worker_index() // (
                 self.user_defined_strategy.sharding_configs[
                     'sharding_group_size'] * self._inner_parallelism_size)
@@ -409,20 +412,33 @@
             print("pp_group_endpoints:", self.pp_group_endpoints)
             print("pp_rank:", self.pp_rank)
             print("pp_ring_id:", self.pp_ring_id)
-            for pair in self.pipeline_pair:
-                if self.pp_rank not in pair: continue
-                pp_group_endpoints = [
-                    self.pp_group_endpoints[pair[0]],
-                    self.pp_group_endpoints[pair[1]],
-                ]
-                if pair[0] < pair[1]:
-                    start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
-                else:
-                    start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
-                pp_rank = 0 if self.pp_rank == pair[0] else 1
+            if self.schedule_mode == 0:  # GPipe
+                self._collective_helper._init_communicator(
+                    self._startup_program, self.current_endpoint,
+                    self.pp_group_endpoints, self.pp_rank, self.pp_ring_id,
+                    False)
                 self._collective_helper._init_communicator(
                     self._startup_program, self.current_endpoint,
-                    pp_group_endpoints, pp_rank, start_ring_id, False)
+                    self.pp_group_endpoints, self.pp_rank, self.pp_ring_id + 2,
+                    False)
+            else:
+                for pair in self.pipeline_pair:
+                    print("pp pair:{}".format(pair))
+                    if self.pp_rank not in pair: continue
+                    pp_group_endpoints = [
+                        self.pp_group_endpoints[pair[0]],
+                        self.pp_group_endpoints[pair[1]],
+                    ]
+                    if pair[0] < pair[1]:
+                        start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
+                    else:
+                        start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[
+                            1] - 1
+                    pp_rank = 0 if self.pp_rank == pair[0] else 1
+                    self._collective_helper._init_communicator(
+                        self._startup_program, self.current_endpoint,
+                        pp_group_endpoints, pp_rank, start_ring_id, False,
+                        False)
 
         startup_block = self._startup_program.global_block()
         startup_block._sync_with_cpp()
diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py
index e9b9bca380433f717b3f7511c06382e1a5b0ad2c..3c5906ceb9df9558dd8fbf7080422a492eeddb8e 100644
--- a/python/paddle/fluid/device_worker.py
+++ b/python/paddle/fluid/device_worker.py
@@ -415,6 +415,7 @@ class Section(DeviceWorker):
         section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
         section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
         section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
+        section_param.schedule_mode = pipeline_opt["schedule_mode"]
         cfg = section_param.section_config
         program = pipeline_opt["section_program"]
         cfg.program_desc.ParseFromString(program["program"]._get_desc()
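For context on what the patch wires together: `schedule_mode` travels from `sharding_configs` through `main_program._pipeline_opt` into the `Section` device worker, and it also picks the communicator layout. Mode `0` (GPipe) initializes two rings spanning the whole pipeline group (`pp_ring_id` and `pp_ring_id + 2`), while any other mode keeps the per-pair send/recv rings, now created with the new `sync=False` argument so `_init_communicator` skips the CUDA port-sync variable. Below is a minimal usage sketch, assuming the fleet APIs of this Paddle revision; the concrete values (group size, accumulation steps) are illustrative and not taken from the patch.

```python
# Sketch only: how the new "schedule_mode" knob would be supplied through
# the sharding strategy. Keys other than those visible in the diff are
# omitted; treat the exact value set as an assumption for this revision.
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()       # this patch targets the static-graph pipeline path
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "use_pipeline": True,      # run pipeline parallelism on top of sharding
    "acc_steps": 4,            # gradient-accumulation steps (micro-batches)
    "sharding_group_size": 8,  # ranks per sharding group
    "schedule_mode": 0,        # 0: GPipe schedule -> two whole-group rings
                               # (pp_ring_id, pp_ring_id + 2); any other value
                               # keeps the per-pair send/recv rings
}

optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# optimizer.minimize(loss) would then propagate schedule_mode into
# main_program._pipeline_opt and on to the Section device worker.
```

A likely motivation for `sync=False` on the pairwise rings (an inference, not stated in the patch): each rank participates in several point-to-point rings, so performing the blocking sync-variable dance per ring would serialize startup, whereas the two whole-group GPipe rings keep the default synchronous initialization.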