From acdaa4fb3435d0016f6883660540a58b07008335 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 12 Oct 2022 19:32:15 +0800 Subject: [PATCH] bugfix (#46921) --- .../auto_parallel_data_parallel_optimization.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index da0c46a8eb..8470aa5109 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -82,9 +82,11 @@ class DataParallelOptimizationPass(PassBase): with paddle.static.program_guard(main_program, startup_program): self._analyze_program() - self._prune_grad_scaling() - self._calc_comm_overlap() - grad_group = self._fuse_allreduce() + + if self.is_data_parallel_applied(): + self._prune_grad_scaling() + self._calc_comm_overlap() + grad_group = self._fuse_allreduce() # self.summary(grad_group) @@ -167,6 +169,9 @@ class DataParallelOptimizationPass(PassBase): ) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format( not_synchronized_grads) + def is_data_parallel_applied(self): + return len(self._group_to_grad_name_map) > 0 + def _could_be_prune(self): return self.dist_context.gradient_scale and ( -- GitLab