Unverified · Commit cdb6b746 authored by W wuhuachaocoding, committed by GitHub

Sharding3 dp final base (#49144)

Parent b0e9e48d
@@ -85,6 +85,7 @@ class GroupShardedStage3(nn.Layer):
pertrain_sync_models=True,
offload=False,
sync_comm=False,
dp_group=None,
):
super().__init__()
@@ -120,6 +121,7 @@ class GroupShardedStage3(nn.Layer):
if group is None
else group
)
self._dp_group = dp_group
self._world_size_scaling = 1.0 / self._group.nranks
assert (
self._group.nranks > 1
@@ -201,6 +203,13 @@ class GroupShardedStage3(nn.Layer):
dist.broadcast(
p, src=self._global_root_rank, group=self._group, sync_op=True
)
if self._dp_group is not None and self._dp_group.nranks > 1:
dist.broadcast(
p,
src=self._dp_group.ranks[0],
group=self._dp_group,
sync_op=True,
)
def _clear_gradients(self):
assert len(self._trainable_params.keys()) > 0
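For context, the parameter-sync hunk above first broadcasts each parameter inside the sharding group and then, when a `dp_group` is present, broadcasts it again from that group's first rank. A minimal sketch of the same pattern as a standalone helper is shown below; the helper name, the `root_rank` argument, and the assumption of an already-launched distributed job are illustrative, not part of this commit.

```python
import paddle.distributed as dist

def sync_params_over_groups(params, sharding_group, dp_group=None, root_rank=0):
    """Hypothetical helper mirroring the broadcast pattern in the hunk above.

    Each parameter is first synchronized inside the sharding group and then,
    if an orthogonal data-parallel group exists, re-broadcast from that
    group's first rank so every replica starts from identical weights.
    Assumes `dist.init_parallel_env()` has been called in a launched job.
    """
    for p in params:
        dist.broadcast(p, src=root_rank, group=sharding_group, sync_op=True)
        if dp_group is not None and dp_group.nranks > 1:
            dist.broadcast(p, src=dp_group.ranks[0], group=dp_group, sync_op=True)
```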
@@ -502,6 +511,13 @@ class GroupShardedStage3(nn.Layer):
dist.broadcast(
buffer, self._global_root_rank, self._group, sync_op=True
)
if self._dp_group is not None and self._dp_group.nranks > 1:
dist.broadcast(
buffer,
self._dp_group.ranks[0],
self._dp_group,
sync_op=True,
)
def __getattr__(self, name):
"""Forward missing attributes to wrapped layer."""
@@ -528,12 +544,7 @@ class GroupShardedStage3(nn.Layer):
assert hasattr(
param, "fw_storage"
), "Find {} don't have fw_storage attribute".format(param.name)
# Gradient average
if self._offload:
with device_guard():
param.bw_storage.scale_(scale=self._world_size_scaling)
else:
param.bw_storage.scale_(scale=self._world_size_scaling)
param.fw_storage = _VarBaseWrapper(param)
assert param.fw_storage.grad is None
param.fw_storage._copy_gradient_from(param.bw_storage)
@@ -543,6 +554,12 @@ class GroupShardedStage3(nn.Layer):
for grad_storage in self._grad_storages.values():
grad_storage.buffer.scale_(scale=self._world_size_scaling)
dist.all_reduce(tensor=grad_storage.buffer, group=self._group)
if self._dp_group is not None and self._dp_group.nranks > 1:
grad_storage.buffer.scale_(scale=(1.0 / self._dp_group.nranks))
dist.all_reduce(
tensor=grad_storage.buffer, group=self._dp_group
)
if self._offload:
for param in list(self._unslice_params):
param._clear_data()
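With `dp_group` set, the fused gradient buffers are now averaged twice: once over the sharding group and once over the data-parallel group. The two-stage scale-then-all-reduce is equivalent to a single average over all `sharding_nranks * dp_nranks` participating ranks, which the toy NumPy check below illustrates (the group sizes and gradient values are made up):

```python
import numpy as np

sharding_nranks, dp_nranks = 2, 2
# One made-up local gradient per rank, laid out as [dp][sharding].
grads = np.array([[1.0, 3.0],
                  [5.0, 7.0]])

# Stage 1: scale by 1/sharding_nranks, all-reduce inside each sharding group.
stage1 = (grads / sharding_nranks).sum(axis=1, keepdims=True).repeat(sharding_nranks, axis=1)
# Stage 2: scale by 1/dp_nranks, all-reduce across the dp group.
stage2 = (stage1 / dp_nranks).sum(axis=0)

assert np.allclose(stage2, grads.mean())  # same as one global average
print(stage2)  # [4. 4.]
```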
@@ -609,7 +626,11 @@ class GroupShardedStage3(nn.Layer):
if param.name in self._task_flow.full_grad.keys():
full_grad = self._task_flow.full_grad[param.name]
# Only support sync allreduce current rank's layer now
full_grad.scale_(scale=self._world_size_scaling)
dist.all_reduce(tensor=full_grad, group=self._group)
if self._dp_group is not None and self._dp_group.nranks > 1:
full_grad.scale_(scale=1.0 / self._dp_group.nranks)
dist.all_reduce(tensor=full_grad, group=self._dp_group)
start, end = self._param2buffer[param.name][self._rank]
if param.bw_storage is None:
......
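Right after the (optional) dp-group all-reduce, each rank copies only its own `[start, end)` slice of the reduced `full_grad` into `param.bw_storage`, using the offsets recorded in `_param2buffer`. The NumPy toy below sketches that bookkeeping for a flattened gradient split across a hypothetical 4-rank sharding group; the sizes are made up and this is not the Paddle implementation itself.

```python
import numpy as np

nranks = 4
full_grad = np.arange(12, dtype=np.float32)        # already all-reduced, length 12

# Hypothetical param2buffer entry: per-rank (start, end) offsets into full_grad.
chunk = len(full_grad) // nranks
param2buffer = [(r * chunk, (r + 1) * chunk) for r in range(nranks)]

rank = 2
start, end = param2buffer[rank]
bw_storage = full_grad[start:end].copy()           # this rank's gradient shard
print(start, end, bw_storage)                      # 6 9 [6. 7. 8.]
```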
@@ -119,11 +119,7 @@ class GroupShardedClipGrad:
global_unslice_fp32 = paddle.sum(global_unslice_fp32)
global_unslice_var = global_unslice_fp16 + global_unslice_fp32
global_norm_var = (
global_norm_fp16
+ global_norm_fp32
+ 1.0 / self._group.nranks * global_unslice_var
)
global_norm_var = global_norm_fp16 + global_norm_fp32
# add all reduce to get global norm of distributed params_and_grads
dev_id = int(self._device.split(":")[1])
@@ -133,7 +129,7 @@ class GroupShardedClipGrad:
with device_guard(dev_id, self._device.split(":")[0]):
paddle.distributed.all_reduce(global_norm_var, group=self._group)
global_norm_var = paddle.sqrt(global_norm_var)
global_norm_var = paddle.sqrt(global_norm_var + global_unslice_var)
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
)
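The clip-grad change moves the contribution of the unsliced (replicated) parameters out of the all-reduced term: instead of adding `global_unslice_var / nranks` before the all-reduce, it is added once afterwards, inside the square root. When the unsliced norm is identical on every rank the two formulations agree, as the small check below illustrates (the values are made up):

```python
import numpy as np

nranks = 4
sliced_norm_sq = np.array([1.0, 2.0, 3.0, 4.0])   # per-rank norm^2 of sharded grads
unslice_norm_sq = 5.0                              # replicated, same on every rank

# Old: fold the replicated term into the all-reduce, pre-scaled by 1/nranks.
old = np.sqrt(np.sum(sliced_norm_sq + unslice_norm_sq / nranks))
# New: all-reduce only the sharded part, then add the replicated term once.
new = np.sqrt(np.sum(sliced_norm_sq) + unslice_norm_sq)

assert np.isclose(old, new)
print(old, new)  # both equal sqrt(15)
```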
@@ -150,9 +146,9 @@ class GroupShardedClipGrad:
origin_state = g.stop_gradient
g.stop_gradient = True
if p.dtype == paddle.float16:
g.scale_(clip_var_fp16.item())
g.scale_(clip_var_fp16)
else:
g.scale_(clip_var.item())
g.scale_(clip_var)
g.stop_gradient = origin_state
# p._reset_grad_inplace_version(True)
......
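For reference, the `g.scale_(...)` calls above apply a global-norm clip coefficient of the form `clip_norm / max(global_norm, clip_norm)` in place. A single-process toy of that step is shown below; the tensors and clip value are made up and this is not the sharded implementation itself.

```python
import paddle

grads = [paddle.to_tensor([3.0, 4.0]), paddle.to_tensor([0.0, 12.0])]
clip_norm = 5.0

# Global norm over all gradients: sqrt(9 + 16 + 0 + 144) = 13.
global_norm = paddle.sqrt(sum(paddle.sum(g * g) for g in grads))
clip_coeff = clip_norm / paddle.maximum(global_norm, paddle.to_tensor(clip_norm))

for g in grads:
    g.scale_(float(clip_coeff))   # in-place rescale, as in the hunk above

print([g.numpy() for g in grads])  # combined norm is now <= 5
```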
@@ -79,7 +79,7 @@ def group_sharded_parallel(
buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. The larger the size, the more GPU memory will be used. Defaults to 2**23, which means that the dimension of the buffer is 2**23.
segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20.
sync_comm (bool, optional): Whether to use synchronous communication, only in `p_g_os` used. Defaults to False, indicating that asynchronous communication is used.
dp_group(Group, optional): dp communication group, only support to combine stage2 and dp hybrid communication now.
dp_group(Group, optional): dp communication group; supports combining stage2 or stage3 with dp hybrid communication.
Returns:
model: A wrapper for group sharded given model.
@@ -192,6 +192,7 @@ def group_sharded_parallel(
segment_size=segment_size,
offload=offload,
sync_comm=sync_comm,
dp_group=dp_group,
device=device,
)
else:
......
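With this change, the public `group_sharded_parallel` entry point forwards `dp_group` to stage3 as well as stage2. A hedged usage sketch follows; the 2×2 group layout, model, and hyper-parameters are illustrative assumptions, and the script is meant to be launched on 4 GPUs with `python -m paddle.distributed.launch`.

```python
import paddle
import paddle.distributed as dist
from paddle.distributed.sharding import group_sharded_parallel

dist.init_parallel_env()
rank = dist.get_rank()

# Illustrative 2-way sharding x 2-way data-parallel layout on 4 ranks.
# Every process creates all groups, then keeps the ones it belongs to.
sharding_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
dp_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]
sharding_group = sharding_groups[rank // 2]
dp_group = dp_groups[rank % 2]

model = paddle.nn.Linear(1024, 1024)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=model.parameters())

# "p_g_os" selects stage3 (parameters, gradients and optimizer states sharded).
model, optimizer, scaler = group_sharded_parallel(
    model,
    optimizer,
    level="p_g_os",
    group=sharding_group,
    dp_group=dp_group,
)
```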