Unverified · Commit 8d2ce06e authored by wuhuachaocoding, committed by GitHub

add test for stage2 + dp (#47114)

* add test for stage2 + dp

* update test for stage2 + dp.

* update.

* update.
Parent f0408778
@@ -205,7 +205,7 @@ class GroupShardedStage2(nn.Layer):
         Before the gradient accumulation, scale the gradient.
         """
-        if self._dp_group is None:
+        if self._dp_group is None or self._dp_group.nranks <= 1:
             scale_factor = self._world_size_scaling
         else:
             scale_factor = 1.0 / (self._group.nranks * self._dp_group.nranks)
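
Note on the hunk above: a missing or single-rank dp group now falls back to the plain sharding-group scaling, and only a real (nranks > 1) dp group divides by both group sizes. A minimal standalone sketch of that rule, where `grad_scale_factor`, `sharding_degree` and `dp_degree` are illustrative stand-ins for the inline computation over `self._group.nranks` and `self._dp_group.nranks`:

```python
# Illustrative sketch only; GroupShardedStage2 computes this factor inline.
def grad_scale_factor(sharding_degree, dp_degree=None):
    # No dp group, or a trivial single-rank one: average over the sharding
    # group only (roughly what self._world_size_scaling holds).
    if dp_degree is None or dp_degree <= 1:
        return 1.0 / sharding_degree
    # stage2 + dp hybrid: gradients are averaged over both groups.
    return 1.0 / (sharding_degree * dp_degree)


# e.g. 4-way sharding combined with 2-way data parallelism averages over 8 replicas.
assert grad_scale_factor(4, 2) == 1.0 / 8
assert grad_scale_factor(4) == 1.0 / 4
```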
@@ -296,7 +296,7 @@ class GroupShardedStage2(nn.Layer):
                                  self._group,
                                  sync_op=True)
-            if self._dp_group:
+            if self._dp_group and self._dp_group.nranks > 1:
                 collective.broadcast(buffer,
                                      self._dp_group.ranks[0],
                                      self._dp_group,
@@ -369,8 +369,8 @@ class GroupShardedStage2(nn.Layer):
                                           group=self._group,
                                           sync_op=not self._reduce_overlap))
-                    if self._dp_group:
-                        assert not self._comm_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
+                    if self._dp_group and self._dp_group.nranks > 1:
+                        assert not self._reduce_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
                         #TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
                         collective.all_reduce(tensor=param.grad,
                                               group=self._dp_group,
@@ -426,8 +426,8 @@ class GroupShardedStage2(nn.Layer):
                                               group=self._group,
                                               sync_op=not self._reduce_overlap))
-                        if self._dp_group:
-                            assert not self._comm_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
+                        if self._dp_group and self._dp_group.nranks > 1:
+                            assert not self._reduce_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
                             #TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
                             collective.all_reduce(tensor=grad_storage.buffer,
                                                   group=self._dp_group,
......
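
The dp branches in group_sharded_stage2.py above all follow the same pattern: act only when the dp group really has more than one rank, and keep the extra dp communication synchronous until the new communication library supports overlap. A hedged sketch of the gradient-side pattern, reusing the `collective.all_reduce` call the file itself makes (`_dp_allreduce_grad` is a hypothetical helper, not part of the patch):

```python
from paddle.distributed import collective


def _dp_allreduce_grad(grad, dp_group, reduce_overlap=False):
    """Hypothetical helper mirroring the stage2 + dp gradient sync above."""
    # A missing or single-rank dp group means pure stage2: nothing extra to do.
    if dp_group is None or dp_group.nranks <= 1:
        return
    # Overlapping this all-reduce with compute is not supported yet, so the
    # reduce-overlap switch must be off whenever a dp group is attached.
    assert not reduce_overlap, (
        'dp + stage2 hybrid parallel only Synchronize due to the new '
        'communication lib.')
    # sync_op=True: block until the dp-wide all-reduce of the gradient finishes.
    collective.all_reduce(tensor=grad, group=dp_group, sync_op=True)
```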
@@ -75,7 +75,9 @@ def train_mlp(model,
               shard_level,
               use_multi_precision,
               output_dir,
-              amp_level='O1'):
+              amp_level='O1',
+              sync_buffers=False,
+              dp_group=None):
     optimizer = optimizer_setting(model=model,
                                   use_multi_precision=use_multi_precision)
     model = paddle.amp.decorate(models=model,
@@ -86,7 +88,9 @@
     model, optimizer, scaler = group_sharded_parallel(model=model,
                                                       optimizer=optimizer,
                                                       level=shard_level,
-                                                      scaler=scaler)
+                                                      scaler=scaler,
+                                                      sync_buffers=sync_buffers,
+                                                      dp_group=dp_group)
     train_reader = paddle.batch(reader_decorator(),
                                 batch_size=batch_size,
@@ -134,6 +138,18 @@ def test_sharding_api():
     output_dir = tempfile.mkdtemp()
 
+    #test sharding + dp, just for test
+    dp_group = paddle.distributed.new_group(
+        list(range(paddle.distributed.get_world_size())))
+
+    stage2_dp_params = train_mlp(mlp1,
+                                 shard_level="os_g",
+                                 use_multi_precision=True,
+                                 output_dir=output_dir,
+                                 amp_level='O2',
+                                 sync_buffers=True,
+                                 dp_group=dp_group)
+
     # fp16
     stage2_params = train_mlp(mlp1,
                               shard_level="os_g",
......
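
For context, a minimal end-to-end sketch of the stage2 + dp path the new test exercises via the public `group_sharded_parallel` API. The model, optimizer and hyper-parameters below are placeholders rather than the test's MLP, and, as in the test, the dp group spans every rank purely to cover the code path; a multi-process GPU launch (e.g. `python -m paddle.distributed.launch ...`) is assumed:

```python
import paddle
import paddle.nn as nn
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

# Placeholder model/optimizer/scaler standing in for the test's MLP setup.
model = nn.Linear(128, 128)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3,
                                   parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# "just for test": a dp group covering every rank, as in the diff above.
dp_group = paddle.distributed.new_group(
    list(range(paddle.distributed.get_world_size())))

model, optimizer, scaler = group_sharded_parallel(
    model=model,
    optimizer=optimizer,
    level="os_g",        # stage2: shard optimizer states and gradients
    scaler=scaler,
    sync_buffers=True,   # broadcast model buffers across ranks after wrapping
    dp_group=dp_group)   # extra data-parallel group on top of the sharding group
```

With `sync_buffers=True` the wrapped model broadcasts its buffers once after construction, and the attached `dp_group` triggers the additional synchronous all-reduce shown in the group_sharded_stage2.py hunks above.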