未验证 提交 53e50383 编写于 作者: B Baibaifan 提交者: GitHub

[Dygraph]fix_sharding3_offload (#42955)

* fix_sharding3_offload

* fix_fp16dtype_bug
上级 07dab9da
......@@ -205,7 +205,7 @@ class GroupShardedStage3(nn.Layer):
for param in list(self._unslice_params):
param.clear_gradient(False)
tmp_var = param.cuda(DEV_ID)
param._clear_data()
if tmp_var.dtype == Type.fp32.value and param2dtype[
param.name] == Type.fp16.value:
tmp_var = paddle.cast(tmp_var, Type.fp16.value)
......@@ -272,6 +272,8 @@ class GroupShardedStage3(nn.Layer):
master_tensor = paddle.cast(param, Type.fp32.value)
master_tensor.name = param.name
self._optim._master_weights[param.name] = master_tensor
if self._offload:
param.master_weight = paddle.cast(param, Type.fp32.value).cpu()
param2dtype[param.name] = param.dtype
p_align = self._param2align(param)
self._unslice_params2align[param.name] = p_align
......@@ -369,7 +371,6 @@ class GroupShardedStage3(nn.Layer):
tmp_var.get_tensor().set(param_cpu.get_tensor(), core.CPUPlace())
del tmp_var
param.get_tensor()._set_dims(param_shape)
param._clear_data()
# Current rank param_storage
if self._offload:
......@@ -379,6 +380,9 @@ class GroupShardedStage3(nn.Layer):
value=tmp_tensor,
place=core.CPUPlace(),
name="slice@" + param.name)
with device_guard():
param.master_weight = paddle.cast(param.fw_storage,
Type.fp32.value)
else:
param.fw_storage = core.eager.Tensor(
value=buffer._slice(start, end), name="slice@" + param.name)
......@@ -389,6 +393,7 @@ class GroupShardedStage3(nn.Layer):
master_tensor = paddle.cast(param.fw_storage, Type.fp32.value)
master_tensor.name = param.name
self._optim._master_weights[param.fw_storage.name] = master_tensor
param._clear_data()
def _register_forward_hooks(self, layer):
"""
......@@ -480,9 +485,8 @@ class GroupShardedStage3(nn.Layer):
collective.all_reduce(tensor=grad_storage.buffer, group=self._group)
if self._offload:
for param in list(self._unslice_params):
tmp_var = _device2cpu(param, convert_dtype=True)
tmp_var._share_buffer_to(param)
del tmp_var
param._clear_data()
param.master_weight._share_buffer_to(param)
for grad_storage in self._grad_storages.values():
for p in grad_storage._params:
......@@ -568,7 +572,8 @@ class GroupShardedStage3(nn.Layer):
del self._task_flow.full_param[param.name]
if self._offload:
param.fw_storage = _device2cpu(param.fw_storage, True)
param.fw_storage._clear_data()
param.master_weight._share_buffer_to(param.fw_storage)
return allreduce_
......@@ -856,6 +861,7 @@ def _PartitionParam(param):
if not hasattr(param, "fw_storage"):
setattr(param, "fw_storage", None)
setattr(param, "bw_storage", None)
setattr(param, "master_weight", None)
setattr(param, "status", "all")
setattr(param, "use_count", 0)
return param
......@@ -864,6 +870,7 @@ def _PartitionParam(param):
def _UnsliceParam(param):
if not hasattr(param, "unslice"):
setattr(param, "unslice", True)
setattr(param, "master_weight", None)
return param
......
......@@ -199,7 +199,7 @@ class ShardingStage3(nn.Layer):
param.clear_gradient(False)
param._gradient_set_empty(False)
tmp_var = param.cuda(DEV_ID)
param._clear()
if tmp_var.dtype == Type.fp32.value and param2dtype[
param.name] == Type.fp16.value:
tmp_var = paddle.cast(tmp_var, Type.fp16.value)
......@@ -220,19 +220,14 @@ class ShardingStage3(nn.Layer):
self._optim._param_groups = slice_params + list(
self._unslice_params)
else:
params_name_list = list(map(lambda p: p.name, update_list))
fw_storage_name_list = list(
map(lambda p: p.fw_storage.name, update_list))
for param_group in self._optim._param_groups:
p_group = []
for p in param_group['params']:
if p.name in params_name_list:
if hasattr(p, "fw_storage"):
p_group.append(p.fw_storage)
elif p.name in fw_storage_name_list:
p_group.append(update_list[fw_storage_name_list.index(
p.name)].fw_storage)
elif p in self._unslice_params:
else:
p_group.append(p)
param_group['params'] = p_group
def forward(self, *inputs, **kwargs):
......@@ -268,6 +263,8 @@ class ShardingStage3(nn.Layer):
if param.dtype == Type.fp16.value and not self._offload:
self._optim._master_weights[param.name] = paddle.cast(
param, Type.fp32.value)
if self._offload:
param.master_weight = paddle.cast(param, Type.fp32.value).cpu()
param2dtype[param.name] = param.dtype
p_align = self._param2align(param)
self._unslice_params2align[param.name] = p_align
......@@ -335,11 +332,12 @@ class ShardingStage3(nn.Layer):
self._param2buffer[param.name].append(
(rank_ * pre_buffer, (rank_ + 1) * pre_buffer))
# 3.Flatten layer params and release other rank buffer
self._param_storage(param, buffer_size)
# Record param's dtype
param2dtype[param.name] = param.dtype
# 3.Flatten layer params and release other rank buffer
self._param_storage(param, buffer_size)
def _param_storage(self, param, buffer_size):
"""
This is a function to simplify the handling of parameter InternalStorages.
......@@ -365,13 +363,15 @@ class ShardingStage3(nn.Layer):
tmp_var.value().get_tensor().set(param_cpu.value().get_tensor(),
core.CPUPlace())
param.value().get_tensor()._set_dims(param_shape)
param._clear()
# Current rank param_storage
if self._offload:
param.fw_storage = core.VarBase(
buffer._slice(start, end),
core.CPUPlace(), "slice@" + param.name)
with device_guard(device="cpu"):
param.master_weight = paddle.cast(param.fw_storage,
Type.fp32.value)
else:
param.fw_storage = core.VarBase(
buffer._slice(start, end), "slice@" + param.name)
......@@ -381,6 +381,7 @@ class ShardingStage3(nn.Layer):
if param.dtype == Type.fp16.value and not self._offload:
self._optim._master_weights[param.fw_storage.name] = paddle.cast(
param.fw_storage, Type.fp32.value)
param._clear()
def _register_forward_hooks(self, layer):
"""
......@@ -482,9 +483,8 @@ class ShardingStage3(nn.Layer):
if self._offload:
for param in list(self._unslice_params):
tmp_var = _device2cpu(param, convert_dtype=True)
tmp_var._share_buffer_to(param)
tmp_var._clear()
param._clear()
param.master_weight._share_buffer_to(param)
for grad_storage in self._grad_storages.values():
for p in grad_storage._params:
......@@ -553,6 +553,7 @@ class ShardingStage3(nn.Layer):
cpu_grad = _device2cpu(
core.VarBase(full_grad._slice(start, end))
.detach().clone(), True)
with device_guard(device="cpu"):
param.bw_storage = paddle.add(param.bw_storage,
cpu_grad)
else:
......@@ -581,7 +582,8 @@ class ShardingStage3(nn.Layer):
tmp_var._clear()
if self._offload:
param.fw_storage = _device2cpu(param.fw_storage, True)
param.fw_storage._clear()
param.master_weight._share_buffer_to(param.fw_storage)
return allreduce_
......@@ -869,6 +871,7 @@ def _PartitionParam(param):
if not hasattr(param, "fw_storage"):
setattr(param, "fw_storage", None)
setattr(param, "bw_storage", None)
setattr(param, "master_weight", None)
setattr(param, "status", "all")
setattr(param, "use_count", 0)
return param
......@@ -877,6 +880,7 @@ def _PartitionParam(param):
def _UnsliceParam(param):
if not hasattr(param, "unslice"):
setattr(param, "unslice", True)
setattr(param, "master_weight", None)
return param
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册