From 07c729aa6e1183afc7e5ff113f07b6a6ae98b79c Mon Sep 17 00:00:00 2001
From: Jiabin Yang <360788950@qq.com>
Date: Wed, 13 Jul 2022 11:16:04 +0800
Subject: [PATCH] [Eager] Fix sharding in eager (#44271)

* fix sharding in eager

* support eager sharding
---
 .../fleet/meta_parallel/sharding/group_sharded_stage2.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index 39e92f8878..f13739960b 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -210,9 +210,10 @@ class GroupShardedStage2(nn.Layer):
                     scale=self._world_size_scaling)

         # Scale grads of params
-        for param in self._trainable_params:
-            if param.name in self._param_grads and param.grad is not None:
-                param.grad.scale_(scale=self._world_size_scaling)
+        with paddle.no_grad():
+            for param in self._trainable_params:
+                if param.name in self._param_grads and param.grad is not None:
+                    param.grad.scale_(scale=self._world_size_scaling)
                 # param._reset_grad_inplace_version(True)

         # Scale grads of master params with offload strategy
--
GitLab
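
Note (not part of the patch): a minimal sketch of the pattern this fix applies, scaling gradients in place under paddle.no_grad() so the eager-mode autograd engine does not track the in-place scale_ op on param.grad. The layer, shapes, and world size below are illustrative placeholders, not the GroupShardedStage2 internals.

    # Sketch: in-place gradient scaling in eager mode, guarded by no_grad().
    import paddle

    world_size = 2                       # assumed number of ranks, for illustration
    scaling = 1.0 / world_size

    linear = paddle.nn.Linear(4, 4)
    loss = linear(paddle.randn([8, 4])).mean()
    loss.backward()

    # Same guard as the fix above: keep the in-place scale_ out of autograd.
    with paddle.no_grad():
        for param in linear.parameters():
            if param.grad is not None:
                param.grad.scale_(scale=scaling)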