From eeca5ef6e810f18af11612db6a4cc56fb141bf7e Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Tue, 9 Mar 2021 11:51:40 +0800
Subject: [PATCH] update

---
 .../meta_optimizers/sharding_optimizer.py | 25 +++++++++++++------
 python/paddle/fluid/optimizer.py          |  7 +++---
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 755e390a0bf..e44b34255bd 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -16,7 +16,7 @@
 from paddle.fluid import unique_name, core
 import paddle.fluid as fluid
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
-from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op, OpRole
+from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
 from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
 from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
 from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
@@ -208,7 +208,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             #pp_optimizer._clear_gradients(main_block, param_list)
             accumulated_grad_names = pp_optimizer._accumulate_gradients(
                 main_block)
-            accumulated_grad_names = sorted(accumulated_grad_names)
+            # accumulated_grad_names = sorted(accumulated_grad_names)
+            print("persistable FP32 grad: ")
             print(accumulated_grad_names)
             first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
                 main_block)
@@ -218,7 +219,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self.sharding_ring_id,
                 accumulated_grad_names,
                 self._shard,
-                OpRole.Optimize,
+                core.op_proto_and_checker_maker.OpRole.Optimize,
                 use_calc_stream=True)
             #if not self._shard.has_param(param_name): continue
             ##if not main_block.has_var(grad_name): continue
@@ -470,10 +471,20 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self._main_program.global_block())
 
     def _wait(self, ):
-        endpoints = self.role_maker._get_trainer_endpoints()
-        current_endpoint = endpoints[self.role_maker._worker_index()]
-        if self.role_maker._worker_index() == 0:
-            self._collective_helper._wait(current_endpoint, endpoints)
+        # only the first parallelism group that inits nccl needs to wait.
+        if self._as_outer_parallelism:
+            endpoints = self.role_maker._get_trainer_endpoints()
+            current_endpoint = endpoints[self.role_maker._worker_index()]
+        else:
+            endpoints = self.sharding_group_endpoints[:]
+            current_endpoint = self.sharding_group_endpoints[self.sharding_rank]
+
+        if self._as_outer_parallelism:
+            if self.role_maker._worker_index() == 0:
+                self._collective_helper._wait(current_endpoint, endpoints)
+        else:
+            if self.sharding_rank == 0:
+                self._collective_helper._wait(current_endpoint, endpoints)
 
     # def _wait(self, ):
     #     # only the first parallelsm group that init nccl need to be wait.
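Note on the _wait hunk above: the intent is that only rank 0 of whichever group actually initializes NCCL blocks on the barrier -- the global trainer group when sharding is used as the outer parallelism, otherwise the local sharding group. Below is a minimal standalone sketch of that selection logic for illustration only; the function name and plain arguments are stand-ins, not part of the patch or the Fleet API.

def wait_for_group_init(role_maker, collective_helper, sharding_group_endpoints,
                        sharding_rank, as_outer_parallelism):
    # Only the first rank of the group that initializes NCCL needs to wait.
    if as_outer_parallelism:
        # Sharding is the outer parallelism: wait on the full trainer group.
        endpoints = role_maker._get_trainer_endpoints()
        current_endpoint = endpoints[role_maker._worker_index()]
        should_wait = role_maker._worker_index() == 0
    else:
        # Otherwise wait only within the local sharding group.
        endpoints = sharding_group_endpoints[:]
        current_endpoint = sharding_group_endpoints[sharding_rank]
        should_wait = sharding_rank == 0
    if should_wait:
        collective_helper._wait(current_endpoint, endpoints)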
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 4507635f8c4..347c27e6fcb 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4879,8 +4879,9 @@ class PipelineOptimizer(object):
             if '@BroadCast' in param_name:
                 param_name = param_name[0:param_name.find('@BroadCast')]
             # clear gradient
+            assert param_name in self.origin_main_block.vars, "[{}] not in original main block".format(
+                param_name)
             param_grad_name = self._append_grad_suffix(param_name)
-            accumulated_grad_names.append(param_grad_name)
             if not block.has_var(param_grad_name):
                 self._create_var(
                     block, self.origin_main_block.vars[param_name],
@@ -4925,7 +4926,7 @@ class PipelineOptimizer(object):
                         #self._op_role_var_key: op_role_var
                     })
                 #offset += 1
-                # accumulated_gradient_names.append(param_grad_var.name)
+                accumulated_grad_names.append(param_grad_var.name)
             else:
                 grad_name = op_role_var[i + 1]  # with _0 suffix
                 grad_var = block.vars[grad_name]
@@ -4962,7 +4963,7 @@ class PipelineOptimizer(object):
                         # self._op_role_var_key: op_role_var
                     })
                 offset += 1
-                # accumulated_gradient_names.append(param_grad_var.name)
+                accumulated_grad_names.append(param_grad_var.name)
                 #real_grad_name = grad_name[0:grad_name.find(
                 #    '@GRAD')] + '@GRAD'
                 #real_grad_var = block.vars[
--
GitLab
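Note on the optimizer.py hunks above: accumulated_grad_names is now filled in at the point where each accumulation op is actually emitted inside _accumulate_gradients (guarded by the new assert), rather than being pre-populated from the derived @GRAD names. A rough sketch of that pattern follows, with hypothetical names only; plain dicts stand in for the block and its vars, and collect_accumulated_grads is not a real Paddle API.

def collect_accumulated_grads(origin_main_vars, param_names):
    # origin_main_vars stands in for origin_main_block.vars (name -> var).
    block_vars = dict(origin_main_vars)  # stands in for the working block
    accumulated_grad_names = []
    for param_name in param_names:
        # Fail fast on unknown parameters, mirroring the new assert.
        assert param_name in origin_main_vars, \
            "[{}] not in original main block".format(param_name)
        grad_name = param_name + "@GRAD"  # stands in for _append_grad_suffix
        if grad_name not in block_vars:
            block_vars[grad_name] = object()  # stands in for _create_var
        # Record the name only once the gradient variable really exists.
        accumulated_grad_names.append(grad_name)
    return accumulated_grad_names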