Unverified commit 8d2ce06e, authored by wuhuachaocoding and committed by GitHub

add test for stage2 + dp (#47114)

* add test for stage2 + dp

* update test for stage2 + dp.

* update.

* update.
Parent: f0408778
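
For context, the change under test threads an extra data-parallel group through the stage2 ("os_g") sharding path. Below is a minimal sketch of the usage the new test exercises, assuming a script run under paddle.distributed.launch on at least two GPUs; the Linear model, AdamW optimizer, and loss-scaling value are placeholders rather than the MLP setup from the test itself.

    import paddle
    from paddle.distributed.sharding import group_sharded_parallel

    paddle.distributed.init_parallel_env()

    # Data-parallel group spanning every rank, mirroring the new test case.
    dp_group = paddle.distributed.new_group(
        list(range(paddle.distributed.get_world_size())))

    model = paddle.nn.Linear(10, 10)  # placeholder model
    optimizer = paddle.optimizer.AdamW(parameters=model.parameters())  # placeholder optimizer
    scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

    # level="os_g" is stage2 (optimizer state + gradient sharding); sync_buffers
    # and dp_group are the arguments the updated train_mlp now forwards.
    model, optimizer, scaler = group_sharded_parallel(model=model,
                                                      optimizer=optimizer,
                                                      level="os_g",
                                                      scaler=scaler,
                                                      sync_buffers=True,
                                                      dp_group=dp_group)
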
@@ -205,7 +205,7 @@ class GroupShardedStage2(nn.Layer):
         Before the gradient accumulation, scale the gradient.
         """
-        if self._dp_group is None:
+        if self._dp_group is None or self._dp_group.nranks <= 1:
             scale_factor = self._world_size_scaling
         else:
             scale_factor = 1.0 / (self._group.nranks * self._dp_group.nranks)
@@ -296,7 +296,7 @@ class GroupShardedStage2(nn.Layer):
                                  self._group,
                                  sync_op=True)
-            if self._dp_group:
+            if self._dp_group and self._dp_group.nranks > 1:
                 collective.broadcast(buffer,
                                      self._dp_group.ranks[0],
                                      self._dp_group,
@@ -369,8 +369,8 @@ class GroupShardedStage2(nn.Layer):
                             group=self._group,
                             sync_op=not self._reduce_overlap))
-                if self._dp_group:
-                    assert not self._comm_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
+                if self._dp_group and self._dp_group.nranks > 1:
+                    assert not self._reduce_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
                     #TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
                     collective.all_reduce(tensor=param.grad,
                                           group=self._dp_group,
@@ -426,8 +426,8 @@ class GroupShardedStage2(nn.Layer):
                         group=self._group,
                         sync_op=not self._reduce_overlap))
-            if self._dp_group:
-                assert not self._comm_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
+            if self._dp_group and self._dp_group.nranks > 1:
+                assert not self._reduce_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
                 #TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
                 collective.all_reduce(tensor=grad_storage.buffer,
                                       group=self._dp_group,
......
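
As a reading aid for the GroupShardedStage2 hunks above, here is a simplified, standalone restatement of the hybrid dp + stage2 gradient handling; it is a sketch rather than the class itself, with group, dp_group, world_size_scaling, and reduce_overlap standing in for the corresponding self._* attributes.

    from paddle.distributed import collective

    def grad_scale_factor(group, dp_group, world_size_scaling):
        # No outer data-parallel group, or a trivial one-rank group: keep the
        # original world-size scaling.  Otherwise divide by both group sizes,
        # since gradients are summed across sharding ranks and then dp ranks.
        if dp_group is None or dp_group.nranks <= 1:
            return world_size_scaling
        return 1.0 / (group.nranks * dp_group.nranks)

    def all_reduce_across_dp(grad, dp_group, reduce_overlap):
        # Runs after the usual stage2 reduce inside the sharding group.  Per
        # the assertion in the diff, overlapping this extra all-reduce with
        # compute is not supported yet, so the dp path stays synchronous.
        if dp_group is not None and dp_group.nranks > 1:
            assert not reduce_overlap, 'dp + stage2 currently requires synchronous reduce'
            collective.all_reduce(tensor=grad, group=dp_group, sync_op=True)
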
@@ -75,7 +75,9 @@ def train_mlp(model,
               shard_level,
               use_multi_precision,
               output_dir,
-              amp_level='O1'):
+              amp_level='O1',
+              sync_buffers=False,
+              dp_group=None):
     optimizer = optimizer_setting(model=model,
                                   use_multi_precision=use_multi_precision)
     model = paddle.amp.decorate(models=model,
@@ -86,7 +88,9 @@ def train_mlp(model,
     model, optimizer, scaler = group_sharded_parallel(model=model,
                                                       optimizer=optimizer,
                                                       level=shard_level,
-                                                      scaler=scaler)
+                                                      scaler=scaler,
+                                                      sync_buffers=sync_buffers,
+                                                      dp_group=dp_group)
 
     train_reader = paddle.batch(reader_decorator(),
                                 batch_size=batch_size,
@@ -134,6 +138,18 @@ def test_sharding_api():
     output_dir = tempfile.mkdtemp()
 
+    #test sharding + dp, just for test
+    dp_group = paddle.distributed.new_group(
+        list(range(paddle.distributed.get_world_size())))
+
+    stage2_dp_params = train_mlp(mlp1,
+                                 shard_level="os_g",
+                                 use_multi_precision=True,
+                                 output_dir=output_dir,
+                                 amp_level='O2',
+                                 sync_buffers=True,
+                                 dp_group=dp_group)
+
     # fp16
     stage2_params = train_mlp(mlp1,
                               shard_level="os_g",
......
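
One note on the test setup: the dp_group here deliberately spans every rank and is marked "just for test", so it only smoke-tests the hybrid code path. In a real dp + stage2 run the two groups would usually come from fleet's hybrid topology instead; the sketch below shows one way that might look on 4 GPUs (the degree values are assumptions, and passing the resulting groups as group=/dp_group= to group_sharded_parallel is a setup choice, not something this diff prescribes).

    import paddle.distributed.fleet as fleet

    # Assumed 4-GPU layout: 2-way data parallel x 2-way sharding.
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 2,
        "mp_degree": 1,
        "pp_degree": 1,
        "sharding_degree": 2,
    }
    fleet.init(is_collective=True, strategy=strategy)

    hcg = fleet.get_hybrid_communicate_group()
    dp_group = hcg.get_data_parallel_group()            # candidate for dp_group=
    sharding_group = hcg.get_sharding_parallel_group()  # candidate for group=
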