diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
index da0c46a8eb121aa86cc561758141cd626aaf39aa..8470aa510996128119c2561313d49490897d74e8 100644
--- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
+++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
@@ -82,9 +82,11 @@ class DataParallelOptimizationPass(PassBase):
 
         with paddle.static.program_guard(main_program, startup_program):
             self._analyze_program()
-            self._prune_grad_scaling()
-            self._calc_comm_overlap()
-            grad_group = self._fuse_allreduce()
+
+            if self.is_data_parallel_applied():
+                self._prune_grad_scaling()
+                self._calc_comm_overlap()
+                grad_group = self._fuse_allreduce()
 
         # self.summary(grad_group)
 
@@ -167,6 +169,9 @@ class DataParallelOptimizationPass(PassBase):
         ) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
             not_synchronized_grads)
 
+    def is_data_parallel_applied(self):
+        return len(self._group_to_grad_name_map) > 0
+
     def _could_be_prune(self):
 
         return self.dist_context.gradient_scale and (
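
A minimal, self-contained sketch of the guard pattern this patch introduces. The class below is illustrative only and is not Paddle's `DataParallelOptimizationPass`; only the names `is_data_parallel_applied` and `_group_to_grad_name_map` come from the diff. The point of the change is that the data-parallel rewrites (grad-scaling pruning, communication overlap, allreduce fusion) run only when program analysis has registered at least one data-parallel communication group.

```python
# Illustrative sketch only -- not Paddle's implementation. It mirrors the
# guard added in this diff: the optimization steps are skipped when program
# analysis finds no data-parallel gradient groups.
class PassSketch:
    def __init__(self):
        # Filled by the analysis step; maps a comm group to its grad names.
        self._group_to_grad_name_map = {}

    def _analyze_program(self, groups):
        # Stand-in for the real analysis that scans the program's allreduce ops.
        self._group_to_grad_name_map = dict(groups)

    def is_data_parallel_applied(self):
        # Same check as in the diff: at least one data-parallel group was found.
        return len(self._group_to_grad_name_map) > 0

    def apply(self, groups):
        self._analyze_program(groups)
        if self.is_data_parallel_applied():
            return "ran prune/overlap/fuse"
        return "no data-parallel groups; pass is a no-op"


if __name__ == "__main__":
    print(PassSketch().apply({}))                      # pass is a no-op
    print(PassSketch().apply({"group0": ["w@GRAD"]}))  # ran prune/overlap/fuse
```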