From 997651ab78e60758c085b34354198cdec4003e4e Mon Sep 17 00:00:00 2001
From: sandyhouse
Date: Mon, 8 Feb 2021 21:26:14 +0800
Subject: [PATCH] update, test=develop

---
 paddle/fluid/framework/section_worker.cc      |   4 +-
 .../meta_optimizers/sharding_optimizer.py     | 164 +++++++++++-------
 2 files changed, 105 insertions(+), 63 deletions(-)

diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc
index fb75a26e5b5..13736c49e1e 100644
--- a/paddle/fluid/framework/section_worker.cc
+++ b/paddle/fluid/framework/section_worker.cc
@@ -99,8 +99,8 @@ void SectionWorker::TrainFiles() {
       VLOG(3) << "Update: running op " << op->Type();
       op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
       if (gc) {
-        DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
-                            gc.get());
+        DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
+                            op.get(), unused_vars_, gc.get());
       }
     }
   }
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 2e544ceb718..0fe10dd839e 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -40,6 +40,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             "LarsOptimizer",
             "LambOptimizer",
             "ModelParallelOptimizer",
+            "PipelineOptimizer",
         ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
         self._main_program = None
@@ -98,14 +99,14 @@ class ShardingOptimizer(MetaOptimizerBase):
             pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt)
             main_program = loss.block.program
             main_program._pipeline_opt = dict()
-            pp_rank = self.role_maker._worker_index(
-            ) // self.user_defined_strategy.sharding_configs[
-                'sharding_group_size']
+            pp_rank = self.role_maker._worker_index() // (
+                self.user_defined_strategy.sharding_configs[
+                    'sharding_group_size'] * self._inner_parallelism_size)
             main_program._pipeline_opt['local_rank'] = pp_rank
             main_program._pipeline_opt[
                 'global_rank'] = self.role_maker._worker_index()
             main_program._pipeline_opt['use_sharding'] = True
-            main_program._pipeline_opt['ring_id'] = 1
+            main_program._pipeline_opt['ring_id'] = 2
             optimize_ops, params_grads, program_list = pp_optimizer.minimize(
                 loss, startup_program, parameter_list, no_grad_set)
             self.pipeline_nodes = len(program_list)
@@ -358,16 +359,19 @@ class ShardingOptimizer(MetaOptimizerBase):
         # config sharding & dp groups
         self._init_comm()
         # sharding
+        print("sharding_group_endpoints:", self.sharding_group_endpoints)
+        print("sharding_rank:", self.sharding_rank)
+        print("sharding_ring_id:", self.sharding_ring_id)
         self._collective_helper._init_communicator(
             self._startup_program, self.current_endpoint,
             self.sharding_group_endpoints, self.sharding_rank,
             self.sharding_ring_id, True)

         # inner & outer model parallelism
-        if self._as_outer_parallelism:
-            self._collective_helper._init_communicator(
-                self._startup_program, self.current_endpoint,
-                self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
+        # if self._as_outer_parallelism:
+        #     self._collective_helper._init_communicator(
+        #         self._startup_program, self.current_endpoint,
+        #         self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)

         # dp
         if self.hybrid_dp:
@@ -757,7 +761,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             logging.info("Using Sharing&DP mode !")

         else:
-            if self._as_outer_parallelism:
+            if self._as_outer_parallelism and not self.use_pipeline:
                 self.sharding_ring_id = 1
                 assert self.global_word_size > self._inner_parallelism_size, \
                     "global_word_size: {} should be larger than inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
@@ -801,75 +805,113 @@ class ShardingOptimizer(MetaOptimizerBase):
                 # logging.info("megatron endpoints: {}".format(
                 #     magetron_endpoints))

             if self.use_pipeline:
-                self.sharding_ring_id = 0
-                self.sharding_group_size = self.user_defined_strategy.sharding_configs[
-                    'sharding_group_size']
-                self.sharding_rank = self.global_rank % self.sharding_group_size
-                assert self.sharding_group_size * self.pipeline_nodes == self.role_maker._worker_num(
-                )
-                self.pp_ring_id = 1
-                self.pp_rank = self.global_rank // self.sharding_group_size
-                self.sharding_group_endpoints = [
-                    ep for idx, ep in enumerate(self.endpoints)
-                    if (idx // self.sharding_group_size) == self.pp_rank
-                ]
-                self.pp_group_size = self.pipeline_nodes
-                self.pp_group_endpoints = [
-                    ep for idx, ep in enumerate(self.endpoints)
-                    if (idx % self.sharding_group_size) == self.sharding_rank
-                ]
+                if self._inner_parallelism_size == 1:
+                    self.sharding_ring_id = 0
+                    self.sharding_group_size = self.user_defined_strategy.sharding_configs[
+                        'sharding_group_size']
+                    self.sharding_rank = self.global_rank % self.sharding_group_size
+                    assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
+                    )
+                    self.pp_ring_id = 2
+                    self.pp_rank = self.global_rank // (
+                        self.sharding_group_size * self._inner_parallelism_size)
+                    self.sharding_group_endpoints = [
+                        ep for idx, ep in enumerate(self.endpoints)
+                        if (idx // self.sharding_group_size) == self.pp_rank
+                    ]
+                    self.pp_group_size = self.pipeline_nodes
+                    self.pp_group_endpoints = [
+                        ep for idx, ep in enumerate(self.endpoints)
+                        if (idx % self.sharding_group_size
+                            ) == self.sharding_rank
+                    ]
+                else:
+                    self.sharding_ring_id = 1
+                    self.pp_ring_id = 2
+                    # self.cards_per_node = 8
+                    self.sharding_group_size = self.user_defined_strategy.sharding_configs[
+                        'sharding_group_size']
+                    self.sharding_rank = self.global_rank // self._inner_parallelism_size % self.sharding_group_size
+                    # self.sharding_group_id = self.global_rank // (self._inner_parallelism_size % self.sharding_group_size)
+                    self.sharding_group_endpoints = [
+                        ep for idx, ep in enumerate(self.endpoints)
+                        if (idx // self._inner_parallelism_size %
+                            self.sharding_group_size) == self.sharding_rank
+                    ]
+                    assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
+                    )
+                    self.pp_rank = self.global_rank // (
+                        self.sharding_group_size * self._inner_parallelism_size)
+                    offset = self.sharding_group_size * self._inner_parallelism_size
+                    idx_with_pp_0 = self.global_rank % (
+                        self.sharding_group_size * self._inner_parallelism_size)
+                    self.pp_group_endpoints = []
+                    for i in range(self.pipeline_nodes):
+                        self.pp_group_endpoints.append(self.endpoints[
+                            idx_with_pp_0])
+                        idx_with_pp_0 += offset
+
+                    #self.pp_group_endpoints = [
+                    #    ep for idx, ep in enumerate(self.endpoints)
+                    #    if (idx % self.sharding_group_size) == self.sharding_rank
+                    #]
+                self.mp_group_id = 1
+                self.mp_rank = self.global_rank
+                self.mp_group_size = self.role_maker._worker_num()
+                self.mp_group_endpoints = self.endpoints[:]
+                logging.info("Using Sharing as Outer parallelism mode !")
                 self.dp_ring_id = -1
                 self.dp_rank = -1
                 self.dp_group_size = None
                 self.dp_group_endpoints = None
                 logging.info("Using Sharing with pipeline !")
-            else:
-                self.sharding_ring_id = 0
-                self.sharding_rank = self.global_rank
-                self.sharding_group_size = self.role_maker._worker_num()
-                self.sharding_group_endpoints = self.endpoints
+            #else:
+            #    self.sharding_ring_id = 0
+            #    self.sharding_rank = self.global_rank
+            #    self.sharding_group_size = self.role_maker._worker_num()
+            #    self.sharding_group_endpoints = self.endpoints

-                # sharding parallelism is the only model parallelism in the current setting
-                self.mp_group_id = self.sharding_ring_id
-                self.mp_rank = self.sharding_rank
-                self.mp_group_size = self.sharding_group_size
-                self.mp_group_endpoints = self.sharding_group_endpoints[:]
+            #    # sharding parallelism is the only model parallelism in the current setting
+            #    self.mp_group_id = self.sharding_ring_id
+            #    self.mp_rank = self.sharding_rank
+            #    self.mp_group_size = self.sharding_group_size
+            #    self.mp_group_endpoints = self.sharding_group_endpoints[:]

-                logging.info("Using Sharing alone mode !")
+            #    logging.info("Using Sharing alone mode !")

                 self.dp_ring_id = -1
                 self.dp_rank = -1
                 self.dp_group_size = None
                 self.dp_group_endpoints = None

-            self.pp_ring_id = -1
-            self.pp_rank = -1
-            self.pp_group_size = None
-            self.pp_group_endpoints = None
-            self.dp_ring_id = -1
-            self.dp_rank = -1
-            self.dp_group_size = None
-            self.dp_group_endpoints = None
+            #self.pp_ring_id = -1
+            #self.pp_rank = -1
+            #self.pp_group_size = None
+            #self.pp_group_endpoints = None
+            #self.dp_ring_id = -1
+            #self.dp_rank = -1
+            #self.dp_group_size = None
+            #self.dp_group_endpoints = None
             logging.info("Using Sharing alone mode !")

-        logging.info("global word size: {}".format(self.global_word_size))
-        logging.info("global rank: {}".format(self.global_rank))
-        logging.info("sharding group_size: {}".format(self.sharding_group_size))
-        logging.info("sharding rank: {}".format(self.sharding_rank))
-        logging.info("current model parallelism group_size: {}".format(
-            self.mp_group_size))
-        logging.info("current model parallelism rank: {}".format(self.mp_rank))
-        logging.info("dp group size: {}".format(self.dp_group_size))
-        logging.info("dp rank: {}".format(self.dp_rank))
-        logging.info("current endpoint: {}".format(self.current_endpoint))
-        logging.info("global word endpoints: {}".format(self.endpoints))
-        logging.info("sharding group endpoints: {}".format(
-            self.sharding_group_endpoints))
-        logging.info("current model parallelism group endpoints: {}".format(
-            self.mp_group_endpoints))
-        logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
+        #logging.info("global word size: {}".format(self.global_word_size))
+        #logging.info("global rank: {}".format(self.global_rank))
+        #logging.info("sharding group_size: {}".format(self.sharding_group_size))
+        #logging.info("sharding rank: {}".format(self.sharding_rank))
+        #logging.info("current model parallelism group_size: {}".format(
+        #    self.mp_group_size))
+        #logging.info("current model parallelism rank: {}".format(self.mp_rank))
+        #logging.info("dp group size: {}".format(self.dp_group_size))
+        #logging.info("dp rank: {}".format(self.dp_rank))
+        #logging.info("current endpoint: {}".format(self.current_endpoint))
+        #logging.info("global word endpoints: {}".format(self.endpoints))
+        #logging.info("sharding group endpoints: {}".format(
+        #    self.sharding_group_endpoints))
+        #logging.info("current model parallelism group endpoints: {}".format(
+        #    self.mp_group_endpoints))
+        #logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))

         return
--
GitLab
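
For reference, the pipeline ring this patch wires up (ring_id = 2) takes one worker per pipeline stage, stepping through the endpoint list in strides of sharding_group_size * inner_parallelism_size. Below is a minimal standalone Python sketch of that arithmetic, assuming a hypothetical 16-worker job (4 pipeline stages x 2 sharding ranks x 2-way inner model parallelism); the sizes and endpoint addresses are invented for illustration, while the variable names (inner_parallelism_size, sharding_group_size, pipeline_nodes) follow the patch.

    # Illustration only, not part of the patch: mirrors the pp_rank and
    # pp_group_endpoints arithmetic introduced above, with hypothetical sizes.
    inner_parallelism_size = 2   # assumed degree of inner model parallelism
    sharding_group_size = 2      # sharding_configs['sharding_group_size']
    pipeline_nodes = 4           # assumed number of pipeline stages

    num_workers = inner_parallelism_size * sharding_group_size * pipeline_nodes
    endpoints = ["127.0.0.1:%d" % (6070 + i) for i in range(num_workers)]  # made-up addresses


    def pipeline_group(global_rank):
        """Return (pp_rank, pp_group_endpoints) for one worker."""
        # pp_rank now divides by sharding_group_size * inner_parallelism_size,
        # not by sharding_group_size alone.
        offset = sharding_group_size * inner_parallelism_size
        pp_rank = global_rank // offset
        # Take the worker's position inside its own pipeline stage, then pick
        # the endpoint at that position in every stage (stride = offset).
        idx_with_pp_0 = global_rank % offset
        group = [endpoints[idx_with_pp_0 + i * offset]
                 for i in range(pipeline_nodes)]
        return pp_rank, group


    for rank in (0, 5, 13):
        print(rank, pipeline_group(rank))
    # Ranks 5 and 13 land in the same pipeline group (endpoints 1, 5, 9, 13)
    # at pp_rank 1 and 3 respectively, matching the idx_with_pp_0/offset loop
    # added in the else branch of the patch.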