diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
index 8bbf42b72f2d6d8cb263d1e099044d8baf657a8c..00c72e28a6ffd38185ac4eaf1459485ab83bad3e 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -86,7 +86,7 @@ class ShardingStage3(nn.Layer):
         self._offload = offload
         self._sync_comm = sync_comm
         # segmentation size
-        self._segment_size = segment_size if not offload else 0
+        self._segment_size = segment_size
 
         global DEV
         DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
@@ -191,8 +191,23 @@ class ShardingStage3(nn.Layer):
             param.fw_storage._gradient_set_empty(False)
             param.bw_storage._clear()
         # 2.Handle unslice param
-        for grad_storage in self._grad_storages.values():
-            grad_storage.buffer.zero_()
+        if not self._offload:
+            for grad_storage in self._grad_storages.values():
+                grad_storage.buffer.zero_()
+        else:
+            for param in list(self._unslice_params):
+                param.clear_gradient(False)
+                param._gradient_set_empty(False)
+                tmp_var = param.cuda(DEV_ID)
+                param._clear()
+                if tmp_var.dtype == Type.fp32.value and param2dtype[
+                        param.name] == Type.fp16.value:
+                    tmp_var = paddle.cast(tmp_var, Type.fp16.value)
+                tmp_var._share_buffer_to(param)
+                tmp_var._clear()
+            for grad_storage in self._grad_storages.values():
+                grad_storage.manumal_relase()
+                grad_storage.rebuild()
 
     # Update param memery slice
     def _update_params_slice(self):
@@ -455,6 +470,21 @@ class ShardingStage3(nn.Layer):
                 group=self._group,
                 use_calc_stream=True)
 
+        if self._offload:
+            for param in list(self._unslice_params):
+                tmp_var = _device2cpu(param, convert_dtype=True)
+                tmp_var._share_buffer_to(param)
+                tmp_var._clear()
+
+            for grad_storage in self._grad_storages.values():
+                for p in grad_storage._params:
+                    tmp_g = _device2cpu(p.grad, convert_dtype=True)
+                    p.clear_gradient(False)
+                    p._gradient_set_empty(False)
+                    p._copy_gradient_from(tmp_g)
+                    tmp_g._clear()
+                grad_storage.buffer._clear()
+
         return update_list
 
     def get_all_parameters(self, convert2cpu=False):
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index ee281a0a044f4d9e371d81b00df933f0f22ac26e..0a42b993d5bf2387d1110ae5478b6162ce175483 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -131,14 +131,15 @@ class ShardingClipGrad:
         clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
 
         for p, g in params_grads:
-            if g is None:
-                continue
-            if getattr(p, 'need_clip', True) is False:
+            if getattr(p, 'need_clip', True) is False or g is None:
                 continue
+            origin_state = g.stop_gradient
+            g.stop_gradient = True
             if p.dtype == paddle.float16:
                 g.scale_(clip_var_fp16)
             else:
                 g.scale_(clip_var)
+            g.stop_gradient = origin_state
             p._reset_grad_inplace_version(True)
 
         return params_grads
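
For reference, a minimal standalone sketch of the pattern the ShardingClipGrad hunk introduces: the gradient's stop_gradient flag is saved, forced to True while the in-place scale_ runs so the mutation is not recorded by autograd, then restored. The helper name scale_grads_inplace and the fp32 CPU toy usage are illustrative assumptions, not part of the patch; the patch additionally calls the private p._reset_grad_inplace_version(True), which this sketch omits.

import paddle


def scale_grads_inplace(params_grads, clip_var):
    # Hypothetical helper mirroring the ShardingClipGrad change: guard the
    # in-place scale with a temporary stop_gradient=True, then restore the
    # gradient's original flag.
    clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
    for p, g in params_grads:
        if getattr(p, 'need_clip', True) is False or g is None:
            continue
        origin_state = g.stop_gradient
        g.stop_gradient = True
        if p.dtype == paddle.float16:
            g.scale_(clip_var_fp16)
        else:
            g.scale_(clip_var)
        g.stop_gradient = origin_state
    return params_grads


# Toy usage on CPU with fp32 tensors.
paddle.set_device("cpu")
x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
(x * x).sum().backward()
(p, g), = scale_grads_inplace([(x, x.grad)], paddle.to_tensor([0.5]))
# grad of (x * x).sum() is 2 * x = [2., 4., 6.]; scaled by 0.5 -> [1., 2., 3.]
print(g)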