From 75644cafed79035f6c987cb0bd3c5c5cb3e874b8 Mon Sep 17 00:00:00 2001
From: sandyhouse
Date: Thu, 11 Mar 2021 17:52:13 +0800
Subject: [PATCH] update

---
 .../meta_optimizers/sharding/fp16_helper.py   |  2 +-
 .../fleet/meta_optimizers/sharding/utils.py   |  2 +-
 .../meta_optimizers/sharding_optimizer.py     | 18 ++++++++----------
 python/paddle/fluid/optimizer.py              |  3 +++
 4 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
index 5eb60fd535a..ff05e039f5c 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
@@ -105,7 +105,7 @@ class FP16Utils(object):
                 reversed_x = []
                 reversed_x_paramname = []
                 for input_name in op.desc.input('X'):
-                    param_name = input_name.strip("@GRAD")
+                    param_name = input_name.strip("@GRAD@MERGED")
                     if param_name not in shard.global_params:
                         raise ValueError(
                             "Input 'X' of check_finite_and_unscale must"
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
index 744a7b85778..6081733f395 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
@@ -357,7 +357,7 @@ def get_grad_device(grad_name, shard):
     base_name = None
     # mind the traversal order
     possible_suffixes = [
-        '.cast_fp16@GRAD_0', '.cast_fp16@GRAD', '@GRAD_0', '@GRAD'
+        '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
     ]
     for suffix in possible_suffixes:
         if suffix in grad_name:
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 6c00aa9fd45..176ab170d68 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -103,8 +103,6 @@ class ShardingOptimizer(MetaOptimizerBase):
         self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
         self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
             "pp_allreduce_in_optimize"]
-        self.optimize_offload = self.user_defined_strategy.sharding_configs[
-            "optimize_offload"]
 
         if self.inner_opt is None:
             raise ValueError(
@@ -947,8 +945,9 @@ class ShardingOptimizer(MetaOptimizerBase):
             ]
             self.pp_group_size = self.pipeline_nodes
             self.pp_group_endpoints = [
-                ep for idx, ep in enumerate(self.endpoints) if
-                (idx % self.sharding_group_size) == self.sharding_rank
+                ep for idx, ep in enumerate(self.endpoints)
+                if (idx % self.sharding_group_size
+                    ) == self.sharding_rank
             ]
         else:
             self.mp_group_id = 0
@@ -972,12 +971,11 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self._inner_parallelism_size * self.sharding_group_size)
             self.megatron_rank = self.global_rank % self._inner_parallelism_size
             self.sharding_group_endpoints = [
-                ep for idx, ep in enumerate(self.endpoints) if
-                (idx //
-                 (self._inner_parallelism_size *
-                  self.sharding_group_size)) == self.sharding_group_id
-                and
-                idx % self._inner_parallelism_size == self.megatron_rank
+                ep for idx, ep in enumerate(self.endpoints)
+                if (idx // (self._inner_parallelism_size *
+                            self.sharding_group_size)
+                    ) == self.sharding_group_id and idx %
+                self._inner_parallelism_size == self.megatron_rank
             ]
             print("sharding_endpoint:", self.sharding_group_endpoints)
             print("sharding_rank:", self.sharding_rank)
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index bc0532a4ee0..bd31f8d20a8 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4898,6 +4898,7 @@ class PipelineOptimizer(object):
                             self._op_role_key: self._op_role.Backward,
                         })
                     offset += 1
+                    merged_gradient_names.append(merged_param_grad_name)
                 else:
                     # cast gradient to fp32 to accumulate to merged gradient
                     cast_grad_var_name = param_grad_name + '@TMP'
@@ -4928,6 +4929,8 @@ class PipelineOptimizer(object):
                             self._op_role_var_key: op_role_var
                         })
                     offset += 1
+                    merged_gradient_names.append(merged_param_grad_name)
+        return merged_gradient_names
 
     def _add_sub_blocks(self, main_block, program_list):
         main_program = main_block.program
-- 
GitLab