diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 82e54a89e104ffa4e0a36ff796c0af07e04c883b..d5592cf3e05edd9e3ea789981951d9aaaef8143d 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -428,59 +428,33 @@ class ShardingOptimizer(MetaOptimizerBase): # pp ring if self.pp_degree > 1: - if self.schedule_mode == 'F-then-B': # GPipe - self._collective_helper._init_communicator( - self._startup_program, - self.current_endpoint, - self.pp_group_endpoints, - self.pp_rank, - self.pp_ring_id, - False, - global_ring_id=self.global_ring_id, - sync=False) - # append_naive_sync(startup_block, self.startup_prog_sync_var, - # self.global_ring_id) + for pair in self.pipeline_pair: + pair_key = pair[0] * 1000 + pair[1] + ring_id = self.pp_ring_map[pair_key] + print("pp pair:{}, ring_id: {}".format(pair, ring_id)) + if self.pp_rank not in pair: continue + pp_group_endpoints = [ + self.pp_group_endpoints[pair[0]], + self.pp_group_endpoints[pair[1]], + ] + if pair[0] < pair[1]: + start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1 + else: + start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1 + pp_rank = 0 if self.pp_rank == pair[0] else 1 self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, - self.pp_group_endpoints, - self.pp_rank, - self.pp_ring_id + 2, + pp_group_endpoints, + pp_rank, + ring_id, False, global_ring_id=self.global_ring_id, sync=False) # append_naive_sync(startup_block, self.startup_prog_sync_var, # self.global_ring_id) - else: - assert self.schedule_mode == '1F1B' - for pair in self.pipeline_pair: - pair_key = pair[0] * 1000 + pair[1] - ring_id = self.pp_ring_map[pair_key] - print("pp pair:{}, ring_id: {}".format(pair, ring_id)) - if self.pp_rank not in pair: continue - pp_group_endpoints = [ - self.pp_group_endpoints[pair[0]], - self.pp_group_endpoints[pair[1]], - ] - if pair[0] < pair[1]: - start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1 - else: - start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[ - 1] - 1 - pp_rank = 0 if self.pp_rank == pair[0] else 1 - self._collective_helper._init_communicator( - self._startup_program, - self.current_endpoint, - pp_group_endpoints, - pp_rank, - ring_id, - False, - global_ring_id=self.global_ring_id, - sync=False) - # append_naive_sync(startup_block, self.startup_prog_sync_var, - # self.global_ring_id) - - # TODO (JZ-LIANG) to unify this shit + + # TODO (JZ-LIANG) to unify this shit assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format( self.pp_rank_, self.pp_rank)