From 1de80a5e64e44f18cf5cd6142c7836a2d6670684 Mon Sep 17 00:00:00 2001
From: sandyhouse
Date: Fri, 12 Mar 2021 14:32:04 +0800
Subject: [PATCH] update

---
 .../fleet/meta_optimizers/sharding_optimizer.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 176ab170d68..ce57d812efe 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -213,6 +213,22 @@ class ShardingOptimizer(MetaOptimizerBase):
             # if self._shard.has_param(param_name):
             #     param_list.append(param_name)
             #pp_optimizer._clear_gradients(main_block, param_list)
+            #accumulated_grad_names = pp_optimizer._accumulate_gradients(
+            #    main_block)
+            # accumulated_grad_names = sorted(accumulated_grad_names)
+            if self.pp_allreduce_in_optimize:
+                print("persistable FP32 grad: ")
+                print(accumulated_grad_names)
+                first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
+                    main_block)
+                insert_reduce_ops(
+                    main_block,
+                    first_optimize_op_index,
+                    self.sharding_ring_id,
+                    accumulated_grad_names,
+                    self._shard,
+                    core.op_proto_and_checker_maker.OpRole.Optimize,
+                    use_calc_stream=True)
             #if not self._shard.has_param(param_name): continue
             ##if not main_block.has_var(grad_name): continue
             #assert main_block.has_var(grad_name)
-- 
GitLab
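
Note (not part of the patch above): when pp_allreduce_in_optimize is enabled, the new hunk looks up the index of the first check_finite_and_unscale op in the main block and inserts one reduce op per accumulated gradient name at that position, on the sharding ring, tagged with the Optimize op role and run on the calc stream. The following is a minimal, self-contained sketch of that insertion pattern using plain Python objects; FakeOp, the "c_reduce_sum" op type, and the *_sketch helpers are illustrative stand-ins, not the PaddlePaddle block/op API.

    # reduce_ops_sketch.py -- illustrative only, not part of the patch.
    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class FakeOp:
        type: str
        inputs: List[str] = field(default_factory=list)


    def first_check_finite_and_unscale_idx(ops: List[FakeOp]) -> int:
        # Mirrors the role of get_first_check_finite_and_unscale_op_idx(main_block):
        # the reduces must land before the first check_finite_and_unscale op.
        for idx, op in enumerate(ops):
            if op.type == "check_finite_and_unscale":
                return idx
        raise ValueError("no check_finite_and_unscale op in the block")


    def insert_reduce_ops_sketch(ops: List[FakeOp], grad_names: List[str]) -> None:
        # Mirrors the role of insert_reduce_ops(...): one reduce per accumulated
        # gradient, inserted at the first optimize-op index so every gradient is
        # reduced across the sharding group before the optimizer stage uses it.
        insert_at = first_check_finite_and_unscale_idx(ops)
        for name in grad_names:
            ops.insert(insert_at, FakeOp("c_reduce_sum", [name]))
            insert_at += 1  # keep the gradients in their original order


    if __name__ == "__main__":
        block = [
            FakeOp("sum", ["x@GRAD"]),
            FakeOp("check_finite_and_unscale", ["x@GRAD", "y@GRAD"]),
            FakeOp("sgd", ["x"]),
        ]
        insert_reduce_ops_sketch(block, ["x@GRAD", "y@GRAD"])
        print([op.type for op in block])
        # -> ['sum', 'c_reduce_sum', 'c_reduce_sum', 'check_finite_and_unscale', 'sgd']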