Unverified commit a95b6f33, authored by wuhuachaocoding, committed by GitHub

combine dp and stage2 hybrid parallel. (#46795)

* combine dp and stage2 hybrid parallel.

* update condition.
Parent 71748805
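
This commit threads a new `dp_group` argument through the stage2 sharding stack so that a stage2 sharding group can be replicated across an outer data-parallel group. Below is a minimal sketch of how the two process groups might be obtained before calling the patched API; the 2 x 2 layout and the use of fleet's hybrid configuration are illustrative assumptions, not part of this commit.

    # Sketch only: assumes 4 GPUs split as sharding_degree=2 x dp_degree=2 and
    # that fleet's hybrid communicate group exposes both process groups.
    # Run with: python -m paddle.distributed.launch --gpus 0,1,2,3 train.py
    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 2,        # outer data-parallel replicas
        "mp_degree": 1,
        "pp_degree": 1,
        "sharding_degree": 2,  # ranks that shard optimizer states / gradients
    }
    fleet.init(is_collective=True, strategy=strategy)

    hcg = fleet.get_hybrid_communicate_group()
    sharding_group = hcg.get_sharding_parallel_group()
    dp_group = hcg.get_data_parallel_group()

The sketches after each file below reuse `sharding_group` and `dp_group` from this setup.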
@@ -69,6 +69,7 @@ class GroupShardedOptimizerStage2(Optimizer):
                 offload=False,
                 device="gpu",
                 pertrain_sync_models=True,
+                dp_group=None,
                 **kw):

        super().__init__(learning_rate=optim._learning_rate, parameters=params)

@@ -121,6 +122,8 @@ class GroupShardedOptimizerStage2(Optimizer):
        self._group = new_group(
            _get_global_group().ranks) if group is None else group

+        # only support to combine stage2 and dp hybrid parallel now.
+        self._dp_group = dp_group
        self.world_size = self._group.nranks
        self._rank = self._group.rank
        self._global_root_rank = self._group.ranks[0]

@@ -172,6 +175,12 @@ class GroupShardedOptimizerStage2(Optimizer):
                      group=self._group,
                      sync_op=True)
+            if self._dp_group:
+                broadcast(p,
+                          src=self._dp_group.ranks[0],
+                          group=self._dp_group,
+                          sync_op=True)

    def _update_task(self, task):
        if self._reduce_overlap:
            assert task is not None
......
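
In the optimizer diff above, the parameter sync now broadcasts every parameter twice: first from the root of the sharding group, then, when `dp_group` is given, from the root of the dp group, so all replicas start from identical weights. A standalone sketch of that pattern follows; the helper name `sync_params` is hypothetical, and the group objects come from the setup sketch above.

    import paddle.distributed as dist

    def sync_params(params, sharding_group, dp_group=None):
        """Hypothetical helper mirroring the two-level broadcast in the diff."""
        for p in params:
            # Make the ranks inside one stage2 sharding group identical first ...
            dist.broadcast(p,
                           src=sharding_group.ranks[0],
                           group=sharding_group,
                           sync_op=True)
            # ... then align the sharded replicas across the outer dp group.
            if dp_group is not None:
                dist.broadcast(p,
                               src=dp_group.ranks[0],
                               group=dp_group,
                               sync_op=True)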
@@ -65,7 +65,8 @@ class GroupShardedStage2(nn.Layer):
                 sync_buffers=False,
                 buffer_max_size=2**23,  # 8MB
                 auto_refresh_trainable=True,
-                device="gpu"):
+                device="gpu",
+                dp_group=None):
        super().__init__()

        # training options

@@ -91,6 +92,8 @@ class GroupShardedStage2(nn.Layer):
            0]  # picking ranks index 0 as the reference
        self._default_device = device
+        self._dp_group = dp_group

        # Global statistical parameters
        self._all_params = []
        for optim in self._sharding_optimizers:

@@ -201,24 +204,29 @@ class GroupShardedStage2(nn.Layer):
        """
        Before the gradient accumulation, scale the gradient.
        """
+        if self._dp_group is None:
+            scale_factor = self._world_size_scaling
+        else:
+            scale_factor = 1.0 / (self._group.nranks * self._dp_group.nranks)

        # Scale grad storages
        for dtype in self._grad_storages.keys():
            if not self._offload and self._rank in self._grad_storages[
                    dtype].keys():
                self._grad_storages[dtype][self._rank].buffer.scale_(
-                    scale=self._world_size_scaling)
+                    scale=scale_factor)

        # Scale grads of params
        with paddle.no_grad():
            for param in self._trainable_params:
                if param.name in self._param_grads and param.grad is not None:
-                    param.grad.scale_(scale=self._world_size_scaling)
+                    param.grad.scale_(scale=scale_factor)
                    # param._reset_grad_inplace_version(True)

        # Scale grads of master params with offload strategy
        if self._offload:
-            self._sharding_optimizers[0]._offload_scale_grad(
-                self._world_size_scaling)
+            self._sharding_optimizers[0]._offload_scale_grad(scale_factor)

    def _init_internal_storage(self, needs_fresh):
        """

@@ -288,6 +296,12 @@ class GroupShardedStage2(nn.Layer):
                                    self._group,
                                    sync_op=True)
+            if self._dp_group:
+                collective.broadcast(buffer,
+                                     self._dp_group.ranks[0],
+                                     self._dp_group,
+                                     sync_op=True)

    def __getattr__(self, name):
        """Forward missing attributes to wrapped layer."""
        try:

@@ -355,6 +369,13 @@ class GroupShardedStage2(nn.Layer):
                            group=self._group,
                            sync_op=not self._reduce_overlap))

+                if self._dp_group:
+                    assert not self._comm_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
+                    # TODO(wuhuachao): after the new communication lib upgrading, overlapping the comm of dp + stage2.
+                    collective.all_reduce(tensor=param.grad,
+                                          group=self._dp_group,
+                                          sync_op=True)

                # Clear the task flow and trigger callback to clear the redundant gradient
                # self._clear_task_flow()

@@ -405,6 +426,13 @@ class GroupShardedStage2(nn.Layer):
                            group=self._group,
                            sync_op=not self._reduce_overlap))

+                if self._dp_group:
+                    assert not self._comm_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
+                    # TODO(wuhuachao): after the new communication lib upgrading, overlapping the comm of dp + stage2.
+                    collective.all_reduce(tensor=grad_storage.buffer,
+                                          group=self._dp_group,
+                                          sync_op=True)

                cleanup()

                # Clear the task flow and trigger callback to clear the redundant gradient
......
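
Two things change in the stage2 layer above when `dp_group` is set: the gradient pre-scaling divides by the total replica count, `self._group.nranks * self._dp_group.nranks`, rather than by the sharding world size alone, and after the usual reduce inside the sharding group each gradient is additionally all-reduced across the dp group (synchronously for now, per the TODO). With sharding_degree=2 and dp_degree=2, for example, every gradient ends up averaged over 4 ranks. A sketch of that flow under the same assumed groups follows; the helper name `reduce_grad` is hypothetical, and `dst_rank` stands in for the global rank that owns the parameter's shard.

    import paddle.distributed as dist

    def reduce_grad(grad, dst_rank, sharding_group, dp_group=None):
        """Hypothetical helper mirroring the stage2 gradient path in the diff."""
        # Pre-scale once by the total number of replicas: sharding ranks x dp ranks.
        total_ranks = sharding_group.nranks * (dp_group.nranks if dp_group else 1)
        grad.scale_(scale=1.0 / total_ranks)
        # Sum the gradient onto the rank that keeps this parameter's optimizer state ...
        dist.reduce(grad, dst=dst_rank, group=sharding_group, sync_op=True)
        # ... then sum the sharding-reduced gradient across the dp replicas.
        # Only the owner rank's result matters; stage2 later clears the others.
        if dp_group is not None:
            dist.all_reduce(grad, group=dp_group, sync_op=True)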
@@ -45,7 +45,8 @@ def group_sharded_parallel(model,
                           sync_buffers=False,
                           buffer_max_size=2**23,
                           segment_size=2**20,
-                          sync_comm=False):
+                          sync_comm=False,
+                          dp_group=None):
    """
    Use group_sharded_parallel can perform group shared configuration on the model, optimizer and GradScaler. Level has three string options, 'os', 'os_g' and 'p_g_os' corresponds to three different usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation.
    Usually, optimizer state + gradient segmentation is actually a re optimization of optimizer state segmentation, so optimizer state + gradient segmentation can be used to realize optimizer state segmentation.

@@ -61,6 +62,7 @@ def group_sharded_parallel(model,
        buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. The larger the size, the more GPU memory will be used. Defaults to 2**23, which means that the dimension of the buffer is 2**23.
        segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20.
        sync_comm (bool, optional): Whether to use synchronous communication, only in `p_g_os` used. Defaults to False, indicating that asynchronous communication is used.
+        dp_group(Group, optional): dp communication group, only support to combine stage2 and dp hybrid communication now.
    Returns:
        model: A wrapper for group sharded given model.

@@ -123,12 +125,14 @@ def group_sharded_parallel(model,
            params=optimizer._parameter_list,
            optim=optimizer,
            group=group,
-            offload=offload)
+            offload=offload,
+            dp_group=dp_group)
        model = GroupShardedStage2(model,
                                   optimizer,
                                   group=group,
                                   sync_buffers=sync_buffers,
-                                  buffer_max_size=buffer_max_size)
+                                  buffer_max_size=buffer_max_size,
+                                  dp_group=dp_group)
    else:
        optimizer = ShardingOptimizerStage2(params=model.parameters(),
                                            optim=optimizer,
......
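
Finally, a sketch of how the patched `group_sharded_parallel` entry point could be called with the new argument, continuing the group setup from the first sketch. The model, data, and hyperparameters are illustrative assumptions; `level="os_g"` selects the stage2 branch shown in the diff above.

    import paddle
    from paddle.distributed.sharding import group_sharded_parallel

    model = paddle.nn.Linear(1024, 1024)
    optimizer = paddle.optimizer.AdamW(learning_rate=1e-3,
                                       parameters=model.parameters())

    # "os_g" = optimizer state + gradient sharding; the wrapped scaler is None
    # here because no GradScaler is passed in this sketch.
    model, optimizer, scaler = group_sharded_parallel(model,
                                                      optimizer,
                                                      level="os_g",
                                                      group=sharding_group,
                                                      dp_group=dp_group)

    data = paddle.randn([8, 1024], dtype="float32")
    loss = paddle.mean(model(data))
    loss.backward()  # grads reduced in sharding_group, then all-reduced in dp_group
    optimizer.step()
    optimizer.clear_grad()

Under this setup each of the dp_degree replicas keeps its own sharded optimizer state, and gradients are synchronized across both groups before `optimizer.step()`.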