diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index 5876df9d3ff74e417dcd2e97ddae4c17a2f9edae..d0c33a5f5964d5e68b70c5e342213fbb3768730c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -205,7 +205,7 @@ class GroupShardedStage2(nn.Layer):
         Before the gradient accumulation, scale the gradient.
         """
 
-        if self._dp_group is None:
+        if self._dp_group is None or self._dp_group.nranks <= 1:
             scale_factor = self._world_size_scaling
         else:
             scale_factor = 1.0 / (self._group.nranks * self._dp_group.nranks)
@@ -296,7 +296,7 @@ class GroupShardedStage2(nn.Layer):
                                  self._group,
                                  sync_op=True)
 
-            if self._dp_group:
+            if self._dp_group and self._dp_group.nranks > 1:
                 collective.broadcast(buffer,
                                      self._dp_group.ranks[0],
                                      self._dp_group,
@@ -369,8 +369,8 @@ class GroupShardedStage2(nn.Layer):
                                           group=self._group,
                                           sync_op=not self._reduce_overlap))
 
-                    if self._dp_group:
-                        assert not self._comm_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
+                    if self._dp_group and self._dp_group.nranks > 1:
+                        assert not self._reduce_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
                         #TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
                         collective.all_reduce(tensor=param.grad,
                                               group=self._dp_group,
@@ -426,8 +426,8 @@ class GroupShardedStage2(nn.Layer):
                                               group=self._group,
                                               sync_op=not self._reduce_overlap))
 
-                        if self._dp_group:
-                            assert not self._comm_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
+                        if self._dp_group and self._dp_group.nranks > 1:
+                            assert not self._reduce_overlap, 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
                             #TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
                             collective.all_reduce(tensor=grad_storage.buffer,
                                                   group=self._dp_group,
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
index f733fa1c6def6440c971386fcdfb003063bbc7da..f57dfd7c98d7aca3581c51da66fa2978e219dc04 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
@@ -75,7 +75,9 @@ def train_mlp(model,
               shard_level,
               use_multi_precision,
               output_dir,
-              amp_level='O1'):
+              amp_level='O1',
+              sync_buffers=False,
+              dp_group=None):
     optimizer = optimizer_setting(model=model,
                                   use_multi_precision=use_multi_precision)
     model = paddle.amp.decorate(models=model,
@@ -86,7 +88,9 @@ def train_mlp(model,
     model, optimizer, scaler = group_sharded_parallel(model=model,
                                                       optimizer=optimizer,
                                                       level=shard_level,
-                                                      scaler=scaler)
+                                                      scaler=scaler,
+                                                      sync_buffers=sync_buffers,
+                                                      dp_group=dp_group)
 
     train_reader = paddle.batch(reader_decorator(),
                                 batch_size=batch_size,
@@ -134,6 +138,18 @@ def test_sharding_api():
 
     output_dir = tempfile.mkdtemp()
 
+    #test sharding + dp, just for test
+    dp_group = paddle.distributed.new_group(
+        list(range(paddle.distributed.get_world_size())))
+
+    stage2_dp_params = train_mlp(mlp1,
+                                 shard_level="os_g",
+                                 use_multi_precision=True,
+                                 output_dir=output_dir,
+                                 amp_level='O2',
+                                 sync_buffers=True,
+                                 dp_group=dp_group)
+
     # fp16
     stage2_params = train_mlp(mlp1,
                               shard_level="os_g",
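A minimal usage sketch, mirroring the test change above, of how the new dp_group argument reaches stage2 via group_sharded_parallel. The linear model, the optimizer settings, and the all-ranks dp_group are illustrative placeholders, not part of the patch; it assumes the script is launched on two or more GPUs (e.g. with paddle.distributed.launch).

# Sketch only: placeholder model/optimizer; dp_group spans all ranks as in the
# test above. In a real hybrid setup dp_group would be the data-parallel group
# orthogonal to the sharding group.
import paddle
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

model = paddle.nn.Linear(8, 8)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001,
                                   parameters=model.parameters())
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

dp_group = paddle.distributed.new_group(
    list(range(paddle.distributed.get_world_size())))

model, optimizer, scaler = group_sharded_parallel(model=model,
                                                  optimizer=optimizer,
                                                  level="os_g",
                                                  scaler=scaler,
                                                  sync_buffers=True,
                                                  dp_group=dp_group)

Note that when the passed dp_group has only one rank, the new nranks > 1 guards skip the extra broadcast/all_reduce, so stage2 behaves as it did without a dp_group.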