未验证 提交 acdaa4fb 编写于 作者: J JZ-LIANG 提交者: GitHub

bugfix (#46921)

上级 686fa07a
...@@ -82,6 +82,8 @@ class DataParallelOptimizationPass(PassBase): ...@@ -82,6 +82,8 @@ class DataParallelOptimizationPass(PassBase):
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
self._analyze_program() self._analyze_program()
if self.is_data_parallel_applied():
self._prune_grad_scaling() self._prune_grad_scaling()
self._calc_comm_overlap() self._calc_comm_overlap()
grad_group = self._fuse_allreduce() grad_group = self._fuse_allreduce()
...@@ -167,6 +169,9 @@ class DataParallelOptimizationPass(PassBase): ...@@ -167,6 +169,9 @@ class DataParallelOptimizationPass(PassBase):
) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format( ) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
not_synchronized_grads) not_synchronized_grads)
def is_data_parallel_applied(self):
    """Report whether any gradient group has been recorded.

    Returns:
        bool: True when ``self._group_to_grad_name_map`` holds at least
        one entry, False when it is empty.
    """
    # NOTE(review): presumably the map is filled during program analysis
    # before this is called -- confirm against the rest of the class.
    return len(self._group_to_grad_name_map) > 0
def _could_be_prune(self): def _could_be_prune(self):
return self.dist_context.gradient_scale and ( return self.dist_context.gradient_scale and (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册