Unverified commit a95b6f33, authored by wuhuachaocoding, committed by GitHub

combine dp and stage2 hybrid parallel. (#46795)

* combine dp and stage2 hybrid parallel.

* update condition.
Parent 71748805
@@ -69,6 +69,7 @@ class GroupShardedOptimizerStage2(Optimizer):
offload=False,
device="gpu",
pertrain_sync_models=True,
dp_group=None,
**kw):
super().__init__(learning_rate=optim._learning_rate, parameters=params)
@@ -121,6 +122,8 @@ class GroupShardedOptimizerStage2(Optimizer):
self._group = new_group(
_get_global_group().ranks) if group is None else group
# Only the combination of stage2 and dp hybrid parallelism is supported for now.
self._dp_group = dp_group
self.world_size = self._group.nranks
self._rank = self._group.rank
self._global_root_rank = self._group.ranks[0]
@@ -172,6 +175,12 @@ class GroupShardedOptimizerStage2(Optimizer):
group=self._group,
sync_op=True)
if self._dp_group:
broadcast(p,
src=self._dp_group.ranks[0],
group=self._dp_group,
sync_op=True)
def _update_task(self, task):
if self._reduce_overlap:
assert task is not None
......
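For context, a minimal sketch (not part of the patch) of how two orthogonal communication groups of this kind could be built by hand and then passed in as `group` and `dp_group`. The 8-rank layout, `sharding_degree=4` and `dp_degree=2` are illustrative assumptions, not values prescribed by the change:

```python
# Hypothetical group construction for combining stage2 sharding with dp.
import paddle.distributed as dist

dist.init_parallel_env()

sharding_degree = 4   # ranks that shard optimizer states / gradients
dp_degree = 2         # pure data-parallel replicas of each sharding group
rank = dist.get_rank()

sharding_group = None
dp_group = None

# Group creation is collective, so every process builds every group here
# and keeps only the one it belongs to.
for i in range(dp_degree):
    ranks = list(range(i * sharding_degree, (i + 1) * sharding_degree))
    g = dist.new_group(ranks)
    if rank in ranks:
        sharding_group = g

for i in range(sharding_degree):
    ranks = list(range(i, dp_degree * sharding_degree, sharding_degree))
    g = dist.new_group(ranks)
    if rank in ranks:
        dp_group = g
```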
@@ -65,7 +65,8 @@ class GroupShardedStage2(nn.Layer):
sync_buffers=False,
buffer_max_size=2**23, #8MB
auto_refresh_trainable=True,
device="gpu"):
device="gpu",
dp_group=None):
super().__init__()
# training options
@@ -91,6 +92,8 @@ class GroupShardedStage2(nn.Layer):
0] # picking ranks index 0 as the reference
self._default_device = device
self._dp_group = dp_group
# Global statistical parameters
self._all_params = []
for optim in self._sharding_optimizers:
@@ -201,24 +204,29 @@ class GroupShardedStage2(nn.Layer):
"""
Before the gradient accumulation, scale the gradient.
"""
if self._dp_group is None:
scale_factor = self._world_size_scaling
else:
scale_factor = 1.0 / (self._group.nranks * self._dp_group.nranks)
# Scale grad storages
for dtype in self._grad_storages.keys():
if not self._offload and self._rank in self._grad_storages[
dtype].keys():
self._grad_storages[dtype][self._rank].buffer.scale_(
scale=self._world_size_scaling)
scale=scale_factor)
# Scale grads of params
with paddle.no_grad():
for param in self._trainable_params:
if param.name in self._param_grads and param.grad is not None:
param.grad.scale_(scale=self._world_size_scaling)
param.grad.scale_(scale=scale_factor)
# param._reset_grad_inplace_version(True)
# Scale grads of master params with offload strategy
if self._offload:
self._sharding_optimizers[0]._offload_scale_grad(
self._world_size_scaling)
self._sharding_optimizers[0]._offload_scale_grad(scale_factor)
def _init_internal_storage(self, needs_fresh):
"""
@@ -288,6 +296,12 @@ class GroupShardedStage2(nn.Layer):
self._group,
sync_op=True)
if self._dp_group:
collective.broadcast(buffer,
self._dp_group.ranks[0],
self._dp_group,
sync_op=True)
def __getattr__(self, name):
"""Forward missing attributes to wrapped layer."""
try:
@@ -355,6 +369,13 @@ class GroupShardedStage2(nn.Layer):
group=self._group,
sync_op=not self._reduce_overlap))
if self._dp_group:
assert not self._comm_overlap, 'dp + stage2 hybrid parallel only supports synchronous communication for now due to the new communication lib.'
# TODO(wuhuachao): overlap the dp + stage2 communication after the new communication lib is upgraded.
collective.all_reduce(tensor=param.grad,
group=self._dp_group,
sync_op=True)
# Clear the task flow and trigger callback to clear the redundant gradient
# self._clear_task_flow()
@@ -405,6 +426,13 @@ class GroupShardedStage2(nn.Layer):
group=self._group,
sync_op=not self._reduce_overlap))
if self._dp_group:
assert not self._comm_overlap, 'dp + stage2 hybrid parallel only supports synchronous communication for now due to the new communication lib.'
# TODO(wuhuachao): overlap the dp + stage2 communication after the new communication lib is upgraded.
collective.all_reduce(tensor=grad_storage.buffer,
group=self._dp_group,
sync_op=True)
cleanup()
# Clear the task flow and trigger callback to clear the redundant gradient
......
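Conceptually, the two hunks above give every gradient a two-step synchronization: a reduce to the shard owner inside the sharding group, then a synchronous all-reduce across the dp group. A hedged sketch of that pattern with `paddle.distributed`; the function and its arguments are placeholders, not code from the patch:

```python
import paddle.distributed as dist

def sync_param_grad(grad, dst_rank, sharding_group, dp_group=None):
    # Step 1: accumulate the gradient on the rank that owns this shard.
    dist.reduce(grad, dst=dst_rank, group=sharding_group, sync_op=True)
    # Step 2: combine across data-parallel replicas; overlapping this with
    # backward is not supported yet, so the call is synchronous.
    if dp_group is not None:
        dist.all_reduce(grad, group=dp_group, sync_op=True)
```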
@@ -45,7 +45,8 @@ def group_sharded_parallel(model,
sync_buffers=False,
buffer_max_size=2**23,
segment_size=2**20,
sync_comm=False):
sync_comm=False,
dp_group=None):
"""
Using group_sharded_parallel performs the group sharded configuration on the model, optimizer and GradScaler. Level has three string options, 'os', 'os_g' and 'p_g_os', which correspond to three different usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation.
Usually, optimizer state + gradient segmentation is actually a further optimization of optimizer state segmentation, so optimizer state + gradient segmentation can also be used to realize optimizer state segmentation.
@@ -61,6 +62,7 @@
buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. The larger the size, the more GPU memory will be used. Defaults to 2**23, which means that the dimension of the buffer is 2**23.
segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20.
sync_comm (bool, optional): Whether to use synchronous communication, only used in `p_g_os`. Defaults to False, indicating that asynchronous communication is used.
dp_group (Group, optional): The dp communication group; only the combination of stage2 and dp hybrid parallelism is supported for now.
Returns:
model: A wrapper for group sharded given model.
@@ -123,12 +125,14 @@
params=optimizer._parameter_list,
optim=optimizer,
group=group,
offload=offload)
offload=offload,
dp_group=dp_group)
model = GroupShardedStage2(model,
optimizer,
group=group,
sync_buffers=sync_buffers,
buffer_max_size=buffer_max_size)
buffer_max_size=buffer_max_size,
dp_group=dp_group)
else:
optimizer = ShardingOptimizerStage2(params=model.parameters(),
optim=optimizer,
......
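A hedged end-to-end usage sketch of the new `dp_group` argument. It assumes the common fleet hybrid-topology calls (`fleet.get_hybrid_communicate_group`, `get_sharding_parallel_group`, `get_data_parallel_group`); the degrees, model and optimizer below are placeholders:

```python
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 4,
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
sharding_group = hcg.get_sharding_parallel_group()
dp_group = hcg.get_data_parallel_group()

model = paddle.nn.Linear(1024, 1024)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3,
                                   parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# 'os_g' wraps the model with GroupShardedStage2: gradients are sharded
# inside sharding_group and additionally all-reduced across dp_group.
model, optimizer, scaler = group_sharded_parallel(model,
                                                  optimizer,
                                                  level="os_g",
                                                  scaler=scaler,
                                                  group=sharding_group,
                                                  dp_group=dp_group)
```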