Unverified commit b292dfb8, authored by Baibaifan, committed by GitHub

optimize sharding stage3 offload (#39397)

Parent c5affb78
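For context, stage-3 sharding splits parameters, gradients and optimizer states across ranks, and `offload=True` additionally keeps fp32 master copies on the CPU. The sketch below shows roughly how the wrapper is driven from user code; the model, optimizer and process-group setup are placeholders, and only `segment_size`, `offload` and `sync_comm` are taken from the constructor arguments visible in the first hunk, so the exact signature may differ between Paddle versions.

```python
# Illustrative sketch only: the setup (model, optimizer, group) is a placeholder;
# segment_size / offload / sync_comm mirror the arguments handled in the
# __init__ shown in the first hunk of this diff.
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3

dist.init_parallel_env()
group = dist.new_group(list(range(dist.get_world_size())))

model = paddle.nn.Linear(1024, 1024)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

# Shard parameters, gradients and optimizer states across the group; with
# offload=True the sharded fp32 master weights live on the CPU, and after this
# commit parameter segmentation stays enabled in that mode as well.
model = ShardingStage3(
    model,
    optimizer=optimizer,
    group=group,
    segment_size=2**15,
    offload=True,
    sync_comm=False)
```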
@@ -86,7 +86,7 @@ class ShardingStage3(nn.Layer):
        self._offload = offload
        self._sync_comm = sync_comm
        # segmentation size
        self._segment_size = segment_size if not offload else 0
        self._segment_size = segment_size

        global DEV
        DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
@@ -191,8 +191,23 @@ class ShardingStage3(nn.Layer):
            param.fw_storage._gradient_set_empty(False)
            param.bw_storage._clear()
        # 2.Handle unslice param
        for grad_storage in self._grad_storages.values():
            grad_storage.buffer.zero_()
        if not self._offload:
            for grad_storage in self._grad_storages.values():
                grad_storage.buffer.zero_()
        else:
            for param in list(self._unslice_params):
                param.clear_gradient(False)
                param._gradient_set_empty(False)
                tmp_var = param.cuda(DEV_ID)
                param._clear()
                if tmp_var.dtype == Type.fp32.value and param2dtype[
                        param.name] == Type.fp16.value:
                    tmp_var = paddle.cast(tmp_var, Type.fp16.value)
                tmp_var._share_buffer_to(param)
                tmp_var._clear()
            for grad_storage in self._grad_storages.values():
                grad_storage.manumal_relase()
                grad_storage.rebuild()
    # Update param memory slice
    def _update_params_slice(self):
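The offload branch added above depends on a `Type` enum and a module-level `param2dtype` map that are defined elsewhere in the sharding code and are not part of this hunk. Reconstructed from how they are used here (so the real definitions may differ in detail), they look roughly like this:

```python
# Rough reconstruction from usage in this diff; not copied from the Paddle sources.
from enum import Enum

import paddle


class Type(Enum):
    # dtypes the sharding code distinguishes between
    fp16 = paddle.float16
    fp32 = paddle.float32


# parameter name -> dtype the parameter originally had, so an fp32 copy kept on
# the CPU can be cast back to fp16 once it is moved onto the GPU
# (the paddle.cast(...) step in the else-branch above).
param2dtype = {}
```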
@@ -455,6 +470,21 @@ class ShardingStage3(nn.Layer):
                group=self._group,
                use_calc_stream=True)

        if self._offload:
            for param in list(self._unslice_params):
                tmp_var = _device2cpu(param, convert_dtype=True)
                tmp_var._share_buffer_to(param)
                tmp_var._clear()

            for grad_storage in self._grad_storages.values():
                for p in grad_storage._params:
                    tmp_g = _device2cpu(p.grad, convert_dtype=True)
                    p.clear_gradient(False)
                    p._gradient_set_empty(False)
                    p._copy_gradient_from(tmp_g)
                    tmp_g._clear()
                grad_storage.buffer._clear()

        return update_list

    def get_all_parameters(self, convert2cpu=False):
......
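The `_device2cpu` helper called in the offload block above lives elsewhere in sharding_stage3.py and is not shown in this diff. A plausible reconstruction, inferred from its call sites here: the CPU-resident copies are held in fp32 (the gradient-clearing hunk casts them back to fp16 on the way to the GPU), so `convert_dtype=True` presumably upcasts before the transfer.

```python
# Sketch inferred from the call sites above; the real helper may differ, e.g. in
# how it releases the source tensor (the callers here explicitly _clear() their
# temporaries).
import paddle


def _device2cpu(trans_param, convert_dtype=False):
    # Optionally upcast to the fp32 master dtype, then move the data to CPU.
    if convert_dtype:
        trans_param = paddle.cast(trans_param, paddle.float32)
    return trans_param.cpu()
```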
@@ -131,14 +131,15 @@ class ShardingClipGrad:
        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
            if getattr(p, 'need_clip', True) is False or g is None:
                continue
            origin_state = g.stop_gradient
            g.stop_gradient = True
            if p.dtype == paddle.float16:
                g.scale_(clip_var_fp16)
            else:
                g.scale_(clip_var)
            g.stop_gradient = origin_state
            p._reset_grad_inplace_version(True)

        return params_grads
......
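In the new clipping loop each gradient is scaled in place instead of being replaced by a freshly allocated tensor: `stop_gradient` is saved and restored around `scale_` so the clipping op itself is not recorded, and `p._reset_grad_inplace_version(True)` resets the gradient's in-place version so the update does not trip autograd's version check on the next backward. A stripped-down illustration of the save/scale/restore pattern on a plain tensor, outside of any sharding machinery:

```python
# Plain-tensor illustration of the in-place scaling used by ShardingClipGrad;
# the values here are arbitrary.
import paddle

g = paddle.to_tensor([1.0, 2.0, 3.0])
clip_var = paddle.to_tensor(0.5)

origin_state = g.stop_gradient
g.stop_gradient = True      # do not track the clipping op itself
g.scale_(clip_var)          # clip in place, no new gradient tensor
g.stop_gradient = origin_state

print(g.numpy())            # [0.5 1.  1.5]
```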