diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index b274f7b9b8487b7302b9e27428b33058aa8a65ca..9538364bf894eb5abdbed771c43520ad2378ecb8 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -143,7 +143,8 @@ class DataParallelOptimizationPass(PassBase): def _could_be_prune(self): - return self._support_rescale_grad or self._all_dp_groups_same_degree() + return self.dist_context._gradient_scale and ( + self._support_rescale_grad or self._all_dp_groups_same_degree()) def _all_dp_groups_same_degree(self): return len(