From 8bb509d56e7637f8c8b53ea599f1f062a8a5e780 Mon Sep 17 00:00:00 2001
From: Baibaifan <39549453+Baibaifan@users.noreply.github.com>
Date: Tue, 25 Jan 2022 14:53:36 +0800
Subject: [PATCH] fix_stage3_fp16 (#39171)

---
 .../fleet/meta_parallel/sharding/sharding_stage3.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
index 41c6f92230a..9d7bd937411 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -393,6 +393,7 @@ class ShardingStage3(nn.Layer):
             else:
                 param.bw_storage.scale_(scale=self._world_size_scaling)
             param.fw_storage = _VarBaseWrapper(param)
+            assert param.fw_storage.grad is None
             param.fw_storage._copy_gradient_from(param.bw_storage)
             update_list.append(param)
         return update_list
@@ -495,10 +496,9 @@ class ShardingStage3(nn.Layer):
     def _redefine_opt_step(self):
         params_slice_func = self._update_params_slice
         opt_step = self._optim.step
-        update_scaler = self._optim.update_scaler

         def _opt_step(self):
-            if not update_scaler:
+            if not self.update_scaler:
                 params_slice_func()
             if self.offload:
                 with device_guard(device="cpu"):
--
GitLab
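
Note on the second hunk: the old code captured `update_scaler` once, at the moment the optimizer's step was redefined, so a value set on the optimizer afterwards (e.g. by the fp16 scaler path) was never seen; reading `self.update_scaler` at call time picks up the current value. Below is a minimal, self-contained sketch of that closure-capture pitfall. All names here (Optim, redefine_step_stale, redefine_step_fixed, update_scaler as a plain bool) are invented for illustration and are not Paddle's API.

# Hypothetical sketch of the bug fixed in the second hunk: a flag captured
# when the step method is wrapped never sees later updates, while reading it
# through `self` at call time does.
import types


class Optim:
    def __init__(self):
        self.update_scaler = False  # may be set to True later, e.g. by an fp16 scaler

    def step(self):
        print("optimizer step")


def redefine_step_stale(optim):
    original_step = optim.step
    update_scaler = optim.update_scaler  # captured once; stays False forever

    def _opt_step(self):
        if not update_scaler:        # stale value from wrap time
            print("update param slices")
        original_step()

    optim.step = types.MethodType(_opt_step, optim)


def redefine_step_fixed(optim):
    original_step = optim.step

    def _opt_step(self):
        if not self.update_scaler:   # read the current value at call time
            print("update param slices")
        original_step()

    optim.step = types.MethodType(_opt_step, optim)


if __name__ == "__main__":
    opt = Optim()
    redefine_step_fixed(opt)
    opt.update_scaler = True  # set after the step was wrapped
    opt.step()                # slices are correctly skipped; only "optimizer step" prints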