Unverified · Commit 53e50383 · authored by Baibaifan, committed by GitHub

[Dygraph]fix_sharding3_offload (#42955)

* fix_sharding3_offload

* fix_fp16dtype_bug
Parent 07dab9da
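The commit replaces the per-step `_device2cpu` round trip with a persistent fp32 master weight kept on the CPU: each offloaded parameter caches `param.master_weight = paddle.cast(param, Type.fp32.value).cpu()` once, the optimizer updates that CPU copy, and the updated values are shared back into the (fp16) parameter. Below is a minimal, self-contained sketch of that pattern; the helper names are hypothetical, it runs entirely on CPU, and it copies values where the real GroupShardedStage3/ShardingStage3 code shares buffers (`master_weight._share_buffer_to(param)`).

```python
import paddle

paddle.set_device("cpu")  # keep the sketch CPU-only; the real code mixes GPU params with CPU masters


def build_master_weight(param):
    # Keep a persistent fp32 copy of the (possibly fp16) parameter on the CPU,
    # mirroring the `param.master_weight` attribute this commit introduces.
    return paddle.cast(param, "float32").cpu()


def write_back(param, master_weight):
    # After the optimizer updates the fp32 master copy, cast it back to the
    # parameter dtype and write the values into the parameter. The real code
    # avoids this copy by sharing buffers instead.
    paddle.assign(paddle.cast(master_weight, param.dtype), output=param)


p = paddle.ones([4], dtype="float16")
mw = build_master_weight(p)
mw += 1.0          # stand-in for an optimizer step applied to the CPU master copy
write_back(p, mw)
print(p)           # [2., 2., 2., 2.] in float16
```

Keeping the fp32 copy resident on the CPU is what lets the diff drop the repeated `_device2cpu(param, convert_dtype=True)` calls that previously rebuilt the CPU tensor every step.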
@@ -205,7 +205,7 @@ class GroupShardedStage3(nn.Layer):
         for param in list(self._unslice_params):
             param.clear_gradient(False)
             tmp_var = param.cuda(DEV_ID)
-            param._clear_data()
+
             if tmp_var.dtype == Type.fp32.value and param2dtype[
                     param.name] == Type.fp16.value:
                 tmp_var = paddle.cast(tmp_var, Type.fp16.value)
@@ -272,6 +272,8 @@ class GroupShardedStage3(nn.Layer):
                 master_tensor = paddle.cast(param, Type.fp32.value)
                 master_tensor.name = param.name
                 self._optim._master_weights[param.name] = master_tensor
+            if self._offload:
+                param.master_weight = paddle.cast(param, Type.fp32.value).cpu()
             param2dtype[param.name] = param.dtype
             p_align = self._param2align(param)
             self._unslice_params2align[param.name] = p_align
@@ -369,7 +371,6 @@ class GroupShardedStage3(nn.Layer):
         tmp_var.get_tensor().set(param_cpu.get_tensor(), core.CPUPlace())
         del tmp_var
         param.get_tensor()._set_dims(param_shape)
-        param._clear_data()
 
         # Current rank param_storage
         if self._offload:
@@ -379,6 +380,9 @@ class GroupShardedStage3(nn.Layer):
                 value=tmp_tensor,
                 place=core.CPUPlace(),
                 name="slice@" + param.name)
+            with device_guard():
+                param.master_weight = paddle.cast(param.fw_storage,
+                                                  Type.fp32.value)
         else:
             param.fw_storage = core.eager.Tensor(
                 value=buffer._slice(start, end), name="slice@" + param.name)
@@ -389,6 +393,7 @@ class GroupShardedStage3(nn.Layer):
             master_tensor = paddle.cast(param.fw_storage, Type.fp32.value)
             master_tensor.name = param.name
             self._optim._master_weights[param.fw_storage.name] = master_tensor
+        param._clear_data()
 
     def _register_forward_hooks(self, layer):
         """
@@ -480,9 +485,8 @@ class GroupShardedStage3(nn.Layer):
             collective.all_reduce(tensor=grad_storage.buffer, group=self._group)
         if self._offload:
             for param in list(self._unslice_params):
-                tmp_var = _device2cpu(param, convert_dtype=True)
-                tmp_var._share_buffer_to(param)
-                del tmp_var
+                param._clear_data()
+                param.master_weight._share_buffer_to(param)
 
             for grad_storage in self._grad_storages.values():
                 for p in grad_storage._params:
@@ -568,7 +572,8 @@ class GroupShardedStage3(nn.Layer):
                     del self._task_flow.full_param[param.name]
 
                     if self._offload:
-                        param.fw_storage = _device2cpu(param.fw_storage, True)
+                        param.fw_storage._clear_data()
+                        param.master_weight._share_buffer_to(param.fw_storage)
 
         return allreduce_
@@ -856,6 +861,7 @@ def _PartitionParam(param):
     if not hasattr(param, "fw_storage"):
         setattr(param, "fw_storage", None)
         setattr(param, "bw_storage", None)
+        setattr(param, "master_weight", None)
         setattr(param, "status", "all")
         setattr(param, "use_count", 0)
     return param
@@ -864,6 +870,7 @@ def _PartitionParam(param):
 
 def _UnsliceParam(param):
     if not hasattr(param, "unslice"):
         setattr(param, "unslice", True)
+        setattr(param, "master_weight", None)
     return param
......
@@ -199,7 +199,7 @@ class ShardingStage3(nn.Layer):
             param.clear_gradient(False)
             param._gradient_set_empty(False)
             tmp_var = param.cuda(DEV_ID)
-            param._clear()
+
             if tmp_var.dtype == Type.fp32.value and param2dtype[
                     param.name] == Type.fp16.value:
                 tmp_var = paddle.cast(tmp_var, Type.fp16.value)
@@ -220,19 +220,14 @@ class ShardingStage3(nn.Layer):
             self._optim._param_groups = slice_params + list(
                 self._unslice_params)
         else:
-            params_name_list = list(map(lambda p: p.name, update_list))
-            fw_storage_name_list = list(
-                map(lambda p: p.fw_storage.name, update_list))
             for param_group in self._optim._param_groups:
                 p_group = []
                 for p in param_group['params']:
-                    if p.name in params_name_list:
+                    if hasattr(p, "fw_storage"):
                         p_group.append(p.fw_storage)
-                    elif p.name in fw_storage_name_list:
-                        p_group.append(update_list[fw_storage_name_list.index(
-                            p.name)].fw_storage)
-                    elif p in self._unslice_params:
+                    else:
                         p_group.append(p)
                 param_group['params'] = p_group
 
     def forward(self, *inputs, **kwargs):
@@ -268,6 +263,8 @@ class ShardingStage3(nn.Layer):
             if param.dtype == Type.fp16.value and not self._offload:
                 self._optim._master_weights[param.name] = paddle.cast(
                     param, Type.fp32.value)
+            if self._offload:
+                param.master_weight = paddle.cast(param, Type.fp32.value).cpu()
             param2dtype[param.name] = param.dtype
             p_align = self._param2align(param)
             self._unslice_params2align[param.name] = p_align
@@ -335,11 +332,12 @@ class ShardingStage3(nn.Layer):
                 self._param2buffer[param.name].append(
                     (rank_ * pre_buffer, (rank_ + 1) * pre_buffer))
-            # 3.Flatten layer params and release other rank buffer
-            self._param_storage(param, buffer_size)
 
             # Record param's dtype
             param2dtype[param.name] = param.dtype
 
+            # 3.Flatten layer params and release other rank buffer
+            self._param_storage(param, buffer_size)
+
     def _param_storage(self, param, buffer_size):
         """
         This is a function to simplify the handling of parameter InternalStorages.
@@ -365,13 +363,15 @@ class ShardingStage3(nn.Layer):
         tmp_var.value().get_tensor().set(param_cpu.value().get_tensor(),
                                          core.CPUPlace())
         param.value().get_tensor()._set_dims(param_shape)
-        param._clear()
 
         # Current rank param_storage
         if self._offload:
             param.fw_storage = core.VarBase(
                 buffer._slice(start, end),
                 core.CPUPlace(), "slice@" + param.name)
+            with device_guard(device="cpu"):
+                param.master_weight = paddle.cast(param.fw_storage,
+                                                  Type.fp32.value)
         else:
             param.fw_storage = core.VarBase(
                 buffer._slice(start, end), "slice@" + param.name)
@@ -381,6 +381,7 @@ class ShardingStage3(nn.Layer):
         if param.dtype == Type.fp16.value and not self._offload:
             self._optim._master_weights[param.fw_storage.name] = paddle.cast(
                 param.fw_storage, Type.fp32.value)
+        param._clear()
 
     def _register_forward_hooks(self, layer):
         """
@@ -482,9 +483,8 @@ class ShardingStage3(nn.Layer):
         if self._offload:
             for param in list(self._unslice_params):
-                tmp_var = _device2cpu(param, convert_dtype=True)
-                tmp_var._share_buffer_to(param)
-                tmp_var._clear()
+                param._clear()
+                param.master_weight._share_buffer_to(param)
 
             for grad_storage in self._grad_storages.values():
                 for p in grad_storage._params:
@@ -553,6 +553,7 @@ class ShardingStage3(nn.Layer):
                     cpu_grad = _device2cpu(
                         core.VarBase(full_grad._slice(start, end))
                         .detach().clone(), True)
-                    param.bw_storage = paddle.add(param.bw_storage,
-                                                  cpu_grad)
+                    with device_guard(device="cpu"):
+                        param.bw_storage = paddle.add(param.bw_storage,
+                                                      cpu_grad)
                 else:
@@ -581,7 +582,8 @@ class ShardingStage3(nn.Layer):
                     tmp_var._clear()
 
                     if self._offload:
-                        param.fw_storage = _device2cpu(param.fw_storage, True)
+                        param.fw_storage._clear()
+                        param.master_weight._share_buffer_to(param.fw_storage)
 
         return allreduce_
@@ -869,6 +871,7 @@ def _PartitionParam(param):
     if not hasattr(param, "fw_storage"):
        setattr(param, "fw_storage", None)
        setattr(param, "bw_storage", None)
+       setattr(param, "master_weight", None)
        setattr(param, "status", "all")
        setattr(param, "use_count", 0)
    return param
@@ -877,6 +880,7 @@ def _PartitionParam(param):
 
 def _UnsliceParam(param):
     if not hasattr(param, "unslice"):
         setattr(param, "unslice", True)
+        setattr(param, "master_weight", None)
     return param
......
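For readers untangling the `@@ -220,19 +220,14 @@` hunk above: the optimizer param groups in ShardingStage3 are now rebuilt by checking each parameter for a `fw_storage` attribute instead of matching names against precomputed name lists. A standalone restatement of that loop (hypothetical function name, no Paddle dependency):

```python
def rebuild_param_groups(param_groups):
    # Replace each sliced parameter with its local slice (fw_storage);
    # unsliced parameters are kept as-is.
    for param_group in param_groups:
        p_group = []
        for p in param_group["params"]:
            if hasattr(p, "fw_storage"):
                p_group.append(p.fw_storage)
            else:
                p_group.append(p)
        param_group["params"] = p_group
    return param_groups
```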