diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 0fe10dd839e1bc3538532fb849066d49543a27ce..89ee08126621340a3f9348453921053816d02ac7 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -39,7 +39,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             "AMPOptimizer",
             "LarsOptimizer",
             "LambOptimizer",
-            "ModelParallelOptimizer",
+            # "ModelParallelOptimizer",
             "PipelineOptimizer",
         ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
@@ -358,6 +358,19 @@ class ShardingOptimizer(MetaOptimizerBase):
                                             self._nrings_sharding)
         # config sharding & dp groups
         self._init_comm()
+
+        # inner & outer model parallelism
+        if self._as_outer_parallelism:
+            self._collective_helper._init_communicator(
+                self._startup_program, self.current_endpoint,
+                self.global_group_endpoints, self.global_rank,
+                self.global_group_id, True)
+
+        if self._as_outer_parallelism:
+            self._collective_helper._init_communicator(
+                self._startup_program, self.current_endpoint,
+                self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False)
+
         # sharding
         print("sharding_group_endpoints:", self.sharding_group_endpoints)
         print("sharding_rank:", self.sharding_rank)
@@ -365,13 +378,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         self._collective_helper._init_communicator(
             self._startup_program, self.current_endpoint,
             self.sharding_group_endpoints, self.sharding_rank,
-            self.sharding_ring_id, True)
-
-        # inner & outer model parallelism
-        # if self._as_outer_parallelism:
-        #     self._collective_helper._init_communicator(
-        #         self._startup_program, self.current_endpoint,
-        #         self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
+            self.sharding_ring_id, False)
 
         # dp
         if self.hybrid_dp:
@@ -382,7 +389,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         if self.use_pipeline:
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
-                self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, True)
+                self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, False)
 
         startup_block = self._startup_program.global_block()
         startup_block._sync_with_cpp()
@@ -482,7 +489,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         # group. and each Data Parallelism group should have its own sync of FoundInfinite
         Model_Paramllelism_ring_id = self.sharding_ring_id
         if self._as_outer_parallelism:
-            Model_Paramllelism_ring_id = self.mp_group_id
+            Model_Paramllelism_ring_id = self.global_group_id
         FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
                              Model_Paramllelism_ring_id)
         gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id)
@@ -826,23 +833,42 @@ class ShardingOptimizer(MetaOptimizerBase):
                     ) == self.sharding_rank
             ]
         else:
+            self.mp_group_id = 0
             self.sharding_ring_id = 1
             self.pp_ring_id = 2
+            self.mp_rank = self.global_rank % self._inner_parallelism_size
+            self.mp_group = self.global_rank // self._inner_parallelism_size
+            self.mp_group_endpoints = [
+                ep for idx, ep in enumerate(self.endpoints)
+                if idx // self._inner_parallelism_size == self.mp_group
+            ]
+            print("megatron_group_endpoints:", self.mp_group_endpoints)
+            print("megatron_rank:", self.mp_rank)
             # self.cards_per_node = 8
             self.sharding_group_size = self.user_defined_strategy.sharding_configs[
                 'sharding_group_size']
-            self.sharding_rank = self.global_rank // self._inner_parallelism_size % self.sharding_group_size
-            # self.sharding_group_id = self.global_rank // (self._inner_parallelism_size % self.sharding_group_size)
+            self.sharding_rank = (
+                self.global_rank //
+                self._inner_parallelism_size) % self.sharding_group_size
+            self.sharding_group_id = self.global_rank // (
+                self._inner_parallelism_size * self.sharding_group_size)
+            self.megatron_rank = self.global_rank % self._inner_parallelism_size
             self.sharding_group_endpoints = [
                 ep for idx, ep in enumerate(self.endpoints)
-                if (idx // self._inner_parallelism_size %
-                    self.sharding_group_size) == self.sharding_rank
+                if (idx // (self._inner_parallelism_size *
+                            self.sharding_group_size)
+                    ) == self.sharding_group_id and idx %
+                self._inner_parallelism_size == self.megatron_rank
             ]
+            print("sharding_endpoint:", self.sharding_group_endpoints)
+            print("sharding_rank:", self.sharding_rank)
             assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
             )
             self.pp_rank = self.global_rank // (
-                self.sharding_group_size * self._inner_parallelism_size)
+                self.sharding_group_size *
+                self._inner_parallelism_size) % self.pipeline_nodes
             offset = self.sharding_group_size * self._inner_parallelism_size
+            # TODO: Adjust for dp
             idx_with_pp_0 = self.global_rank % (
                 self.sharding_group_size * self._inner_parallelism_size)
             self.pp_group_endpoints = []
@@ -850,15 +876,17 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self.pp_group_endpoints.append(self.endpoints[
                     idx_with_pp_0])
                 idx_with_pp_0 += offset
+            print("pp_group_endpoints:", self.pp_group_endpoints)
+            print("pp_rank:", self.pp_rank)
             #self.pp_group_endpoints = [
             #    ep for idx, ep in enumerate(self.endpoints)
             #    if (idx % self.sharding_group_size) == self.sharding_rank
             #]
-            self.mp_group_id = 1
-            self.mp_rank = self.global_rank
-            self.mp_group_size = self.role_maker._worker_num()
-            self.mp_group_endpoints = self.endpoints[:]
+            self.global_group_id = 3
+            self.global_rank = self.global_rank
+            self.global_group_size = self.role_maker._worker_num()
+            self.global_group_endpoints = self.endpoints[:]
 
             logging.info("Using Sharing as Outer parallelism mode !")
             self.dp_ring_id = -1
             self.dp_rank = -1
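Reviewer note (not part of the patch): the sketch below restates the rank/group arithmetic added in the @@ -826 hunk as a standalone function, assuming a world size of inner_parallelism_size * sharding_group_size * pipeline_nodes with no outer data parallelism; the helper name hybrid_coords is hypothetical and used only to illustrate how a global rank maps onto megatron, sharding, and pipeline coordinates.

# Standalone sketch (not part of the patch above). Reproduces the index
# arithmetic from the @@ -826 hunk; names below are illustrative only.
def hybrid_coords(global_rank, inner_parallelism_size, sharding_group_size,
                  pipeline_nodes):
    # megatron (inner model parallel) coordinates
    mp_rank = global_rank % inner_parallelism_size
    mp_group = global_rank // inner_parallelism_size
    # sharding coordinates: consecutive megatron groups form one sharding group
    sharding_rank = (global_rank //
                     inner_parallelism_size) % sharding_group_size
    sharding_group_id = global_rank // (inner_parallelism_size *
                                        sharding_group_size)
    # pipeline stage index
    pp_rank = global_rank // (sharding_group_size *
                              inner_parallelism_size) % pipeline_nodes
    return mp_rank, mp_group, sharding_rank, sharding_group_id, pp_rank

# Example: 2-way megatron x 2-way sharding x 2 pipeline stages = 8 ranks.
if __name__ == "__main__":
    for rank in range(8):
        print(rank, hybrid_coords(rank, 2, 2, 2))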