diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index 39e92f88780285b2b430b774c83f5b665cbad998..f13739960b38a1a43725116ced6d01f9b6bbec1a 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -210,9 +210,10 @@ class GroupShardedStage2(nn.Layer):
                     scale=self._world_size_scaling)
 
         # Scale grads of params
-        for param in self._trainable_params:
-            if param.name in self._param_grads and param.grad is not None:
-                param.grad.scale_(scale=self._world_size_scaling)
+        with paddle.no_grad():
+            for param in self._trainable_params:
+                if param.name in self._param_grads and param.grad is not None:
+                    param.grad.scale_(scale=self._world_size_scaling)
                 # param._reset_grad_inplace_version(True)
 
         # Scale grads of master params with offload strategy
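
For reference, a minimal standalone sketch (not the `GroupShardedStage2` code itself) of the pattern the added lines adopt: the in-place gradient rescale runs under `paddle.no_grad()` so the `scale_` op is not tracked by autograd. The toy `Linear` layer, input shape, and `world_size` of 2 below are illustrative assumptions.

```python
import paddle

# Assumed values for illustration only.
world_size = 2
world_size_scaling = 1.0 / world_size

# Stand-in for the sharded trainable parameters.
layer = paddle.nn.Linear(4, 4)
x = paddle.randn([8, 4])
loss = layer(x).mean()
loss.backward()

# Scale grads of params, mirroring the structure of the "+" lines in the diff:
# the in-place scale_ on param.grad is done inside no_grad.
with paddle.no_grad():
    for param in layer.parameters():
        if param.grad is not None:
            param.grad.scale_(scale=world_size_scaling)
```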