diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 176ab170d68e4a6dba000e63bc677b816c5b819f..ce57d812efef237e9726ef3073b9c45e5462394b 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -213,6 +213,22 @@ class ShardingOptimizer(MetaOptimizerBase):
             # if self._shard.has_param(param_name):
             #     param_list.append(param_name)
             #pp_optimizer._clear_gradients(main_block, param_list)
+            accumulated_grad_names = pp_optimizer._accumulate_gradients(
+                main_block)
+            # accumulated_grad_names = sorted(accumulated_grad_names)
+            if self.pp_allreduce_in_optimize:
+                print("persistable FP32 grad: ")
+                print(accumulated_grad_names)
+                first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
+                    main_block)
+                insert_reduce_ops(
+                    main_block,
+                    first_optimize_op_index,
+                    self.sharding_ring_id,
+                    accumulated_grad_names,
+                    self._shard,
+                    core.op_proto_and_checker_maker.OpRole.Optimize,
+                    use_calc_stream=True)
             #if not self._shard.has_param(param_name): continue
             ##if not main_block.has_var(grad_name): continue
             #assert main_block.has_var(grad_name)
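
Note on the new branch: when pp_allreduce_in_optimize is set, the accumulated FP32 gradients are reduced over the sharding ring, with the reduce ops placed right before the first check_finite_and_unscale op in the optimize section. The sketch below only illustrates the intended insertion pattern; insert_reduce_ops_sketch, the shard.device(...) lookup, the c_reduce_sum op type, and the attribute names are assumptions for illustration, not the helper actually defined in the sharding utilities.

    # Minimal sketch, assuming each gradient should be reduced onto the rank
    # that owns the corresponding parameter shard (all names below are
    # illustrative, not taken from the PR).
    def insert_reduce_ops_sketch(block, insert_idx, ring_id, grad_names, shard,
                                 op_role, use_calc_stream=True):
        for grad_name in grad_names:
            grad_var = block.var(grad_name)
            # Assumed shard API: device(name) -> rank owning this grad's param.
            root_id = shard.device(grad_name)
            block._insert_op(
                insert_idx,
                type='c_reduce_sum',
                inputs={'X': grad_var},
                outputs={'Out': grad_var},
                attrs={
                    'ring_id': ring_id,
                    'root_id': root_id,
                    'use_calc_stream': use_calc_stream,
                    'op_role': op_role,
                })

Placing these ops at first_optimize_op_index keeps the gradient reduction inside the optimize stage, so a rank only receives the gradients it needs for its own parameter shard instead of allreducing every gradient to every rank.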