diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
index 41c6f92230ab3e0e8de9aec0abdf920fad1ef232..9d7bd937411882541d9cb1311c241d3e84316c90 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -393,6 +393,7 @@ class ShardingStage3(nn.Layer):
             else:
                 param.bw_storage.scale_(scale=self._world_size_scaling)
             param.fw_storage = _VarBaseWrapper(param)
+            assert param.fw_storage.grad is None
             param.fw_storage._copy_gradient_from(param.bw_storage)
             update_list.append(param)
         return update_list
@@ -495,10 +496,9 @@ class ShardingStage3(nn.Layer):
     def _redefine_opt_step(self):
         params_slice_func = self._update_params_slice
         opt_step = self._optim.step
-        update_scaler = self._optim.update_scaler
 
         def _opt_step(self):
-            if not update_scaler:
+            if not self.update_scaler:
                 params_slice_func()
             if self.offload:
                 with device_guard(device="cpu"):