Unverified commit 3addd568, authored by Haohongxiang, committed by GitHub

[Dygraph] Support grad division to nranks before reduce in sharding stage2 (#47764)

Parent 7964119b
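The commit moves the sharding-group gradient averaging so that each rank divides its local gradients by nranks before the reduce, instead of scaling the summed gradient once afterwards; the hunks below touch GroupShardedOptimizerStage2, GroupShardedStage2, and the test_sharding_api unit test. A minimal NumPy sketch (hypothetical gradient values, not Paddle code) of why the two orderings give the same averaged gradient:

```python
import numpy as np

nranks = 4
# Hypothetical per-rank local gradients.
local_grads = [np.array([1.0, 2.0]) * (rank + 1) for rank in range(nranks)]

# Old ordering: allreduce-sum first, then scale the reduced gradient once.
scaled_after_reduce = sum(local_grads) * (1.0 / nranks)

# New ordering: every rank pre-divides its local gradient by nranks, so the
# allreduce-sum already yields the averaged gradient.
scaled_before_reduce = sum(g * (1.0 / nranks) for g in local_grads)

assert np.allclose(scaled_after_reduce, scaled_before_reduce)
print(scaled_before_reduce)  # [2.5 5. ]
```

One common motivation for pre-dividing is numerical: the partial sums exchanged during the reduce stay smaller, which helps avoid overflow when gradients are kept in fp16.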
@@ -498,12 +498,7 @@ class GroupShardedOptimizerStage2(Optimizer):
         with device_guard(self._rank, self.offload_device):
             self.offload_grads.buffer.zero_()
 
-    def step(self):
-        """
-        A wrapper for Optimizer's step function to finish the update operation of the optimizer.
-        """
-        # This method won't be called directly by opt.step()!
-        # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
+    def _step(self):
         if self._broadcast_overlap:
             # Clear the pre forward hook in the optimizer step.
             for hook_remove in self._forward_pre_hook_remove_helper:
@@ -536,6 +531,14 @@ class GroupShardedOptimizerStage2(Optimizer):
         # Synchronize all the updated shards in between the ranks
         self._broadcast_params()
 
+    def step(self):
+        """
+        A wrapper for Optimizer's step function to finish the update operation of the optimizer.
+        """
+        # This method won't be called directly by opt.step()!
+        # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
+        self._step()
+
     def minimize(self):
         raise RuntimeError(
             "optimizer.minimize() not support now, please use optimizer.step()"
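The two hunks above split the old step() into a private _step() that does the actual update and a public step() that only delegates to it, keeping the docstring and the warning comments on the public entry point. A self-contained toy sketch (hypothetical class names, not the real GroupSharded implementation) of how a wrapper such as _redefine_opt_step() can then rebind step() while still reaching the real update logic through _step():

```python
from types import MethodType


class ToyOptimizer:
    def _step(self):
        # Stand-in for the real parameter update.
        print("apply parameter update")

    def step(self):
        # Public entry point; it only delegates, so a wrapper that replaces
        # step() can still call _step() directly for the real work.
        self._step()


class ToySharding:
    def __init__(self, opt):
        self._opt = opt
        self._redefine_opt_step()

    def _grad_scale(self):
        # Stand-in for scaling/reducing gradients before the update.
        print("scale gradients before the update")

    def _redefine_opt_step(self):
        grad_func = self._grad_scale

        def wrapped_step(opt_self):
            grad_func()
            opt_self._step()

        self._opt.step = MethodType(wrapped_step, self._opt)


opt = ToyOptimizer()
ToySharding(opt)
opt.step()
# scale gradients before the update
# apply parameter update
```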
@@ -225,13 +225,13 @@ class GroupShardedStage2(nn.Layer):
 
     def _grad_scale(self):
         """
-        Before the gradient accumulation, scale the gradient.
+        Before the optimization, scale the gradients before allreduce of dp_group.
         """
 
         if self._dp_group is None or self._dp_group.nranks <= 1:
-            scale_factor = self._world_size_scaling
+            return
         else:
-            scale_factor = 1.0 / (self._group.nranks * self._dp_group.nranks)
+            scale_factor = 1.0 / (self._dp_group.nranks)
 
         # Scale grad storages
         for dtype in self._grad_storages.keys():
@@ -366,6 +366,13 @@ class GroupShardedStage2(nn.Layer):
         ), "Only support comm overlap strategy for single optimizer"
         self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)
 
+    def _get_scaled_grad_fn(self):
+        @paddle.autograd.no_grad()
+        def scale(grad):
+            grad.scale_(self._world_size_scaling)
+
+        return scale
+
     def _get_reduce_fn(self, index, param, dst_rank):
         """
         There are two ways to reduce gradient.
@@ -510,6 +517,8 @@ class GroupShardedStage2(nn.Layer):
             return
 
         for index, param in enumerate(self._trainable_params):
+            param._register_grad_hook(self._get_scaled_grad_fn())
+
             dst_rank = self._trainable_param2rank[param.name]
             reduce_function = self._get_reduce_fn(index, param, dst_rank)
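The _get_scaled_grad_fn hook added above divides every freshly computed gradient by the sharding-group size (self._world_size_scaling) before it reaches the reduce, so _grad_scale only has to apply the remaining 1.0 / self._dp_group.nranks factor; assuming _world_size_scaling equals 1 / sharding-group nranks, the combined factor matches the old 1.0 / (self._group.nranks * self._dp_group.nranks). Below is a minimal single-process sketch of the per-parameter hook; it uses the public paddle.Tensor.register_hook for illustration, whereas the sharding code goes through the internal param._register_grad_hook, and the 0.5 factor stands in for a hypothetical two-rank sharding group:

```python
import paddle

linear = paddle.nn.Linear(4, 2)
scale_factor = 0.5  # hypothetical 1 / sharding_group.nranks for nranks == 2


def scaled_grad_fn(grad):
    # Mirrors _get_scaled_grad_fn: pre-divide the gradient by nranks so the
    # later allreduce-sum produces an already-averaged gradient.
    return grad * scale_factor


# Register the scaling hook on every trainable parameter, as the stage-2
# layer does when it sets up its backward hooks.
for param in linear.parameters():
    param.register_hook(scaled_grad_fn)

x = paddle.randn([8, 4])
loss = linear(x).mean()
loss.backward()
print(linear.weight.grad)  # gradient arrives already divided by nranks
```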
@@ -153,16 +153,6 @@ def test_sharding_api():
         list(range(paddle.distributed.get_world_size()))
     )
 
-    stage2_dp_params = train_mlp(
-        mlp1,
-        shard_level="os_g",
-        use_multi_precision=True,
-        output_dir=output_dir,
-        amp_level='O2',
-        sync_buffers=True,
-        dp_group=dp_group,
-    )
-
     # fp16
     stage2_params = train_mlp(
         mlp1,