diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index c34ec8f45e15f32c73e2072fe8d623be3edc54e8..b9967ca202c80c09db2dcc3a4ab4c47b375180b6 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -20,6 +20,7 @@ import paddle from paddle import _legacy_C_ops from paddle.distributed.parallel import _split_tensors from paddle.fluid import core +from paddle.framework import base as imperative_base __all__ = [] @@ -165,6 +166,7 @@ class FusedAllReduceBuffer: if self._all_params_checked_in: self._fused_allreduce_grads() + @imperative_base.no_grad def _fused_allreduce_grads(self): assert self._all_params_checked_in flattened_vars = [] @@ -188,6 +190,7 @@ class FusedAllReduceBuffer: ) ) + @imperative_base.no_grad def scale_and_split_grads(self): for task in self._tasks: task.wait()