From 279ac753a5e9404178e6dce77fdd25a0d3bbf0e5 Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Fri, 2 Jun 2023 20:56:28 +0800
Subject: [PATCH] fix split_tensor of dp_pp_comm_overlap (#54310)

---
 .../paddle/distributed/fleet/meta_parallel/pp_utils/utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
index c34ec8f45e1..b9967ca202c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -20,6 +20,7 @@ import paddle
 from paddle import _legacy_C_ops
 from paddle.distributed.parallel import _split_tensors
 from paddle.fluid import core
+from paddle.framework import base as imperative_base
 
 __all__ = []
 
@@ -165,6 +166,7 @@ class FusedAllReduceBuffer:
         if self._all_params_checked_in:
             self._fused_allreduce_grads()
 
+    @imperative_base.no_grad
     def _fused_allreduce_grads(self):
         assert self._all_params_checked_in
         flattened_vars = []
@@ -188,6 +190,7 @@ class FusedAllReduceBuffer:
                 )
             )
 
+    @imperative_base.no_grad
     def scale_and_split_grads(self):
         for task in self._tasks:
             task.wait()
-- 
GitLab
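
Note (not part of the patch): a minimal, hypothetical sketch of the pattern the
fix relies on. Decorating a method with @imperative_base.no_grad (imported from
paddle.framework as in the hunk above) runs it with gradient recording
disabled, so the in-place scaling and splitting of fused gradients does not add
ops to the autograd graph. The ToyGradBuffer class and the nranks value below
are illustrative only; they are not part of FusedAllReduceBuffer.

    import paddle
    from paddle.framework import base as imperative_base


    class ToyGradBuffer:
        """Illustrative stand-in for a fused-gradient buffer."""

        def __init__(self, grad):
            self._grad = grad

        @imperative_base.no_grad
        def scale_grads(self, nranks):
            # Executed with autograd recording turned off, mirroring the
            # decorator added to _fused_allreduce_grads and
            # scale_and_split_grads in the patch.
            self._grad.scale_(1.0 / nranks)


    x = paddle.ones([2, 2])
    x.stop_gradient = False
    (x * 2).sum().backward()

    buf = ToyGradBuffer(x.grad)
    buf.scale_grads(nranks=8)  # hypothetical data-parallel world size
    print(buf._grad)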