Unverified commit cdb6b746, authored by wuhuachaocoding, committed by GitHub

Sharding3 dp final base (#49144)

Parent b0e9e48d
@@ -85,6 +85,7 @@ class GroupShardedStage3(nn.Layer):
         pertrain_sync_models=True,
         offload=False,
         sync_comm=False,
+        dp_group=None,
     ):
         super().__init__()
@@ -120,6 +121,7 @@ class GroupShardedStage3(nn.Layer):
             if group is None
             else group
         )
+        self._dp_group = dp_group
         self._world_size_scaling = 1.0 / self._group.nranks
         assert (
             self._group.nranks > 1
@@ -201,6 +203,13 @@ class GroupShardedStage3(nn.Layer):
             dist.broadcast(
                 p, src=self._global_root_rank, group=self._group, sync_op=True
             )
+            if self._dp_group is not None and self._dp_group.nranks > 1:
+                dist.broadcast(
+                    p,
+                    src=self._dp_group.ranks[0],
+                    group=self._dp_group,
+                    sync_op=True,
+                )

     def _clear_gradients(self):
         assert len(self._trainable_params.keys()) > 0
@@ -502,6 +511,13 @@ class GroupShardedStage3(nn.Layer):
             dist.broadcast(
                 buffer, self._global_root_rank, self._group, sync_op=True
             )
+            if self._dp_group is not None and self._dp_group.nranks > 1:
+                dist.broadcast(
+                    buffer,
+                    self._dp_group.ranks[0],
+                    self._dp_group,
+                    sync_op=True,
+                )

     def __getattr__(self, name):
         """Forward missing attributes to wrapped layer."""
@@ -528,12 +544,7 @@ class GroupShardedStage3(nn.Layer):
             assert hasattr(
                 param, "fw_storage"
             ), "Find {} don't have fw_storage attribute".format(param.name)
-            # Gradient average
-            if self._offload:
-                with device_guard():
-                    param.bw_storage.scale_(scale=self._world_size_scaling)
-            else:
-                param.bw_storage.scale_(scale=self._world_size_scaling)
             param.fw_storage = _VarBaseWrapper(param)
             assert param.fw_storage.grad is None
             param.fw_storage._copy_gradient_from(param.bw_storage)
@@ -543,6 +554,12 @@ class GroupShardedStage3(nn.Layer):
         for grad_storage in self._grad_storages.values():
             grad_storage.buffer.scale_(scale=self._world_size_scaling)
             dist.all_reduce(tensor=grad_storage.buffer, group=self._group)
+            if self._dp_group is not None and self._dp_group.nranks > 1:
+                grad_storage.buffer.scale_(scale=(1.0 / self._dp_group.nranks))
+                dist.all_reduce(
+                    tensor=grad_storage.buffer, group=self._dp_group
+                )
+
         if self._offload:
             for param in list(self._unslice_params):
                 param._clear_data()
@@ -609,7 +626,11 @@ class GroupShardedStage3(nn.Layer):
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
                 # Only support sync allreduce current rank's layer now
+                full_grad.scale_(scale=self._world_size_scaling)
                 dist.all_reduce(tensor=full_grad, group=self._group)
+                if self._dp_group is not None and self._dp_group.nranks > 1:
+                    full_grad.scale_(scale=1.0 / self._dp_group.nranks)
+                    dist.all_reduce(tensor=full_grad, group=self._dp_group)

                 start, end = self._param2buffer[param.name][self._rank]
                 if param.bw_storage is None:
...
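Reading the stage3 hunks together: with a `dp_group`, parameters and buffers are broadcast twice (from the root of the sharding group, then from the root of the dp group), and every gradient is averaged twice, by `1 / self._group.nranks` with an all-reduce over the sharding group and then by `1 / self._dp_group.nranks` with an all-reduce over the dp group. The snippet below is a minimal single-process sketch (plain Python lists stand in for ranks and `dist.all_reduce`; group sizes and values are made up) showing that this two-stage reduction equals a plain mean over all `sharding_nranks * dp_nranks` replicas.

```python
# Single-process sketch of the two-stage gradient averaging (illustrative only).
sharding_nranks, dp_nranks = 4, 2

# grads[dp_rank][sharding_rank]: the local gradient value held by each rank.
grads = [
    [float(10 * d + s) for s in range(sharding_nranks)] for d in range(dp_nranks)
]

# Stage 1: scale by 1/sharding_nranks, then all-reduce(SUM) inside each sharding group.
stage1 = [sum(g / sharding_nranks for g in row) for row in grads]

# Stage 2: scale by 1/dp_nranks, then all-reduce(SUM) across the dp group.
stage2 = sum(v / dp_nranks for v in stage1)

# Equivalent single-step mean over every replica.
flat = [g for row in grads for g in row]
assert abs(stage2 - sum(flat) / len(flat)) < 1e-12
```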
@@ -119,11 +119,7 @@ class GroupShardedClipGrad:
         global_unslice_fp32 = paddle.sum(global_unslice_fp32)
         global_unslice_var = global_unslice_fp16 + global_unslice_fp32

-        global_norm_var = (
-            global_norm_fp16
-            + global_norm_fp32
-            + 1.0 / self._group.nranks * global_unslice_var
-        )
+        global_norm_var = global_norm_fp16 + global_norm_fp32

         # add all reduce to get global norm of distributed params_and_grads
         dev_id = int(self._device.split(":")[1])
@@ -133,7 +129,7 @@ class GroupShardedClipGrad:
         with device_guard(dev_id, self._device.split(":")[0]):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)

-        global_norm_var = paddle.sqrt(global_norm_var)
+        global_norm_var = paddle.sqrt(global_norm_var + global_unslice_var)
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
         )
@@ -150,9 +146,9 @@ class GroupShardedClipGrad:
             origin_state = g.stop_gradient
             g.stop_gradient = True
             if p.dtype == paddle.float16:
-                g.scale_(clip_var_fp16.item())
+                g.scale_(clip_var_fp16)
             else:
-                g.scale_(clip_var.item())
+                g.scale_(clip_var)
             g.stop_gradient = origin_state
             # p._reset_grad_inplace_version(True)
...
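The `GroupShardedClipGrad` hunks move the unsliced-parameter contribution out of the all-reduced sum: previously it was pre-scaled by `1 / self._group.nranks` and reduced together with the sliced norms (so the reduce summed `nranks` identical copies back to one); now it is added once after the all-reduce, just before the square root. A rough sanity check with made-up per-rank values, showing the two forms produce the same global norm:

```python
# Illustrative numbers only: sliced squared norms differ per rank and need the
# all-reduce(SUM); the unsliced squared norm is replicated on every rank.
nranks = 4
sliced_sq = [1.0, 2.0, 3.0, 4.0]  # per-rank squared norms of sharded grads
unslice_sq = 5.0                  # identical on all ranks

# Old form: fold a 1/nranks share into each rank's value, then all-reduce(SUM).
old_norm = sum(s + unslice_sq / nranks for s in sliced_sq) ** 0.5

# New form: all-reduce(SUM) only the sliced part, add the unsliced norm once.
new_norm = (sum(sliced_sq) + unslice_sq) ** 0.5

assert abs(old_norm - new_norm) < 1e-12
```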
@@ -79,7 +79,7 @@ def group_sharded_parallel(
         buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. The larger the size, the more GPU memory will be used. Defaults to 2**23, which means that the dimension of the buffer is 2**23.
         segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20.
         sync_comm (bool, optional): Whether to use synchronous communication, only in `p_g_os` used. Defaults to False, indicating that asynchronous communication is used.
-        dp_group(Group, optional): dp communication group, only support to combine stage2 and dp hybrid communication now.
+        dp_group(Group, optional): dp communication group, support to combine stage2 or stage3 with dp hybrid communication.

     Returns:
         model: A wrapper for group sharded given model.
@@ -192,6 +192,7 @@ def group_sharded_parallel(
             segment_size=segment_size,
             offload=offload,
             sync_comm=sync_comm,
+            dp_group=dp_group,
             device=device,
         )
     else:
...
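With the stage3 class now accepting `dp_group`, `group_sharded_parallel` can forward it at `level="p_g_os"` as well. Below is a hedged usage sketch: the 2-way dp layout, the group construction, the toy model, and the optimizer are illustrative assumptions, not part of this commit; only the `dp_group=` keyword comes from the diff above.

```python
# Hypothetical launch script (run with `python -m paddle.distributed.launch ...`).
import paddle
import paddle.distributed as dist
from paddle.distributed.sharding import group_sharded_parallel

dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()

# Assume 2-way data parallelism on top of (world_size // 2)-way sharding.
sharding_degree = world_size // 2
sharding_group = dp_group = None

# Every rank must create every group; keep the ones this rank belongs to.
for start in range(0, world_size, sharding_degree):
    g = dist.new_group(list(range(start, start + sharding_degree)))
    if start <= rank < start + sharding_degree:
        sharding_group = g
for offset in range(sharding_degree):
    g = dist.new_group(list(range(offset, world_size, sharding_degree)))
    if rank % sharding_degree == offset:
        dp_group = g

model = paddle.nn.Linear(1024, 1024)  # placeholder model
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

# level="p_g_os" wraps the model in GroupShardedStage3; dp_group triggers the
# extra broadcast / all-reduce over the data-parallel replicas shown in the diff.
model, optimizer, scaler = group_sharded_parallel(
    model,
    optimizer,
    level="p_g_os",
    group=sharding_group,
    dp_group=dp_group,
)
```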