# Review notes for this patch. Text placed before the first "diff --git"
# header is not part of the patch payload.
#
# NOTE(review): the line breaks and indentation of this unified diff have been
# collapsed, so it cannot be applied as-is. The payload below is left
# byte-identical; restore it from the upstream commit rather than by hand.
# The hunk headers are self-consistent with the "+"/"-" markers that survive:
#   hunk 1 (-379,6 +379,119): 6 context lines + 113 added lines
#   hunk 2 (-435,31 +548,7):  6 context lines + 25 removed + 1 added line
#
# What the patch does (as far as the tokens show):
#   * Adds ShardingOptimizer._init_pair_comm(pair, ring_id): calls
#     self._collective_helper._init_communicator for the two endpoints
#     self.pp_group_endpoints[pair[0]] and [pair[1]] with sync=False; the
#     rank passed is 0 when self.pp_rank == pair[0], otherwise 1.
#   * Adds _init_npu_pipeline_comm(startup_block): asserts pp_degree is even,
#     collects the pipeline pairs that contain this rank into my_pair, then
#     initialises pair communicators in a fixed order, calling
#     append_naive_sync after each one:
#       1. even send to next / odd recv from prev   (0->1, 2->3)
#       2. even recv from next / odd send to prev   (1->0, 3->2)
#       3. odd send to next / even recv from prev   (1->2, 3->0)
#       4. odd recv from next / even send to prev   (2->1, 0->3)
#     Step 3 is under "if self.pp_degree > 2"; step 4 presumably is too.
#     Ring ids are looked up in self.pp_ring_map under the key
#     pair[0] * 1000 + pair[1]; steps 3 and 4 fall back to max_ring_id + 1 and
#     max_ring_id + 2 for the wrap-around pairs that are not in the map.
#     A final assert requires my_pair to be empty.
#   * Adds _init_pipeline_comm(startup_block): asserts pp_rank_ == pp_rank,
#     delegates to the NPU method when core.is_compiled_with_npu() is true,
#     otherwise loops over self.pipeline_pair and calls _init_pair_comm for
#     the pairs that contain this rank.
#   * In _init_comm, replaces the inline pp-ring setup with a call to
#     self._init_pipeline_comm(startup_block); print() becomes logger.info().
#
# NOTE(review): nesting that cannot be recovered from the collapsed text and
# must be confirmed against the upstream file:
#   - whether append_naive_sync in the GPU loop (and in the removed code) is
#     inside "if self.pp_rank in pair", inside the for loop, or after it;
#   - whether the logger.info calls of steps 3 and 4 sit inside the
#     "pp_rank != 0 and pp_rank != pp_degree - 1" guard;
#   - whether step 4 and the final assert sit inside "if self.pp_degree > 2".
# NOTE(review): "assert (self.pp_degree % 2) == 0" carries no message; a
# message naming pp_degree would make the failure easier to diagnose.
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 8211f3ea0fb9d2e25a01e9181a87a01e531833c9..8a591120c0289edd373582f7173d607219b816a2 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -379,6 +379,119 @@ class ShardingOptimizer(MetaOptimizerBase): self._wait() return optimize_ops, params_grads + def _init_pair_comm(self, pair, ring_id): + pp_group_endpoints = [ + self.pp_group_endpoints[pair[0]], + self.pp_group_endpoints[pair[1]], + ] + pp_rank = 0 if self.pp_rank == pair[0] else 1 + self._collective_helper._init_communicator( + self._startup_program, + self.current_endpoint, + pp_group_endpoints, + pp_rank, + ring_id, + False, + global_ring_id=self.global_ring_id, + sync=False) + + def _init_npu_pipeline_comm(self, startup_block): + # NOTE(wangxi): some bug with hccl, must set pp_degree be even number + assert (self.pp_degree % 2) == 0 + + max_ring_id = -1 + my_pair = [] + for pair in self.pipeline_pair: + pair_key = pair[0] * 1000 + pair[1] + ring_id = self.pp_ring_map[pair_key] + max_ring_id = max(max_ring_id, ring_id) + logger.info("pp pair:{}, ring_id: {}".format(pair, ring_id)) + + if self.pp_rank in pair: + my_pair.append(pair) + + # for example: self.pp_rank=2, self.pp_degree=4 + send_to_next_pair = (self.pp_rank, + (self.pp_rank + 1) % self.pp_degree) # 2->3 + recv_from_next_pair = ((self.pp_rank + 1) % self.pp_degree, + self.pp_rank) # 3->2 + recv_from_prev_pair = ((self.pp_rank - 1 + self.pp_degree) % + self.pp_degree, self.pp_rank) # 1->2 + send_to_prev_pair = (self.pp_rank, (self.pp_rank - 1 + self.pp_degree) % + self.pp_degree) # 2->1 + + even = (self.pp_rank % 2) == 0 + + # 1. 
even send to next, odd recv from prev, 0->1, 2->3 + pair = send_to_next_pair if even else recv_from_prev_pair + ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]] + self._init_pair_comm(pair, ring_id) + append_naive_sync(startup_block, self.startup_prog_sync_var, + self.global_ring_id) + my_pair.remove(pair) + logger.info("pair0(even->odd): pp pair:{}, ring_id: {}".format(pair, + ring_id)) + + # 2. even recv from next, odd send to prev, 1->0, 3->2 + pair = recv_from_next_pair if even else send_to_prev_pair + ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]] + self._init_pair_comm(pair, ring_id) + append_naive_sync(startup_block, self.startup_prog_sync_var, + self.global_ring_id) + my_pair.remove(pair) + logger.info("pair1(even<-odd): pp pair:{}, ring_id: {}".format(pair, + ring_id)) + + # if pp_degree is 2, only need pair(0->1, 1->0) + if self.pp_degree > 2: + # 3. odd send to next, even recv from prev, 1->2, 3->0 + pair = send_to_next_pair if not even else recv_from_prev_pair + ring_id = self.pp_ring_map.get( + pair[0] * 1000 + pair[1], + max_ring_id + 1) # 3->0 not in pp_ring_map + self._init_pair_comm(pair, ring_id) + append_naive_sync(startup_block, self.startup_prog_sync_var, + self.global_ring_id) + if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1: + my_pair.remove(pair) + logger.info("pair2(odd->even): pp pair:{}, ring_id: {}".format( + pair, ring_id)) + + # 4. 
odd recv from next, even send to prev, 2->1, 0->3 + pair = recv_from_next_pair if not even else send_to_prev_pair + ring_id = self.pp_ring_map.get( + pair[0] * 1000 + pair[1], + max_ring_id + 2) # 0->3 not in pp_ring_map + self._init_pair_comm(pair, ring_id) + append_naive_sync(startup_block, self.startup_prog_sync_var, + self.global_ring_id) + if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1: + my_pair.remove(pair) + logger.info("pair3(odd<-even): pp pair:{}, ring_id: {}".format( + pair, ring_id)) + + assert len(my_pair) == 0, "Current pipeline does not support cross stage communication, " \ + "please check unexpected pair {}".format(my_pair) + + def _init_pipeline_comm(self, startup_block): + # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank + assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format( + self.pp_rank_, self.pp_rank) + + if core.is_compiled_with_npu(): + self._init_npu_pipeline_comm(startup_block) + return + + # GPU + for pair in self.pipeline_pair: + pair_key = pair[0] * 1000 + pair[1] + ring_id = self.pp_ring_map[pair_key] + logger.info("pp pair:{}, ring_id: {}".format(pair, ring_id)) + if self.pp_rank in pair: + self._init_pair_comm(pair, ring_id) + append_naive_sync(startup_block, self.startup_prog_sync_var, + self.global_ring_id) + def _init_comm(self): # config sharding & dp groups @@ -435,31 +548,7 @@ class ShardingOptimizer(MetaOptimizerBase): # pp ring if self.pp_degree > 1: - # TODO (JZ-LIANG) to unify this shit - assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format( - self.pp_rank_, self.pp_rank) - - for pair in self.pipeline_pair: - pair_key = pair[0] * 1000 + pair[1] - ring_id = self.pp_ring_map[pair_key] - print("pp pair:{}, ring_id: {}".format(pair, ring_id)) - if self.pp_rank in pair: - pp_group_endpoints = [ - self.pp_group_endpoints[pair[0]], - self.pp_group_endpoints[pair[1]], - ] - pp_rank = 0 if self.pp_rank == pair[0] else 1 - 
self._collective_helper._init_communicator( - self._startup_program, - self.current_endpoint, - pp_group_endpoints, - pp_rank, - ring_id, - False, - global_ring_id=self.global_ring_id, - sync=False) - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) + self._init_pipeline_comm(startup_block) # pure dp ring if self.dp_degree > 1: