From 1c8531ce9ff6d3935b5216cb6943f68b4662a711 Mon Sep 17 00:00:00 2001
From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com>
Date: Fri, 13 Jan 2023 15:18:06 +0800
Subject: [PATCH] fix a bug of stage2 offload. (#49767)

---
 .../group_sharded_optimizer_stage2.py        |  6 +++++
 .../dygraph_group_sharded_stage2_offload.py  | 25 ++++++++++++++++---
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 9a25d7c491..f5ca60b100 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -149,6 +149,11 @@ class GroupShardedOptimizerStage2(Optimizer):
         self._rank = self._group.rank
         self._global_root_rank = self._group.ranks[0]
 
+        if self._dp_group is not None and self._dp_group.nranks > 1:
+            assert (
+                not offload
+            ), "Not support! when using offload with sharding stage2, please use pure sharding stage2, exclude data parallel."
+
         # Synchronous all ranks models
         if pertrain_sync_models:
             self._sync_params_and_buffers()
@@ -164,6 +169,7 @@ class GroupShardedOptimizerStage2(Optimizer):
             if (
                 hcg
                 and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+                and not offload
             ):
                 self._optim._grad_clip = HybridParallelClipGrad(
                     self._optim._grad_clip, hcg
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_offload.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_offload.py
index 8b6b9241bd..e868b4ff34 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_offload.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_offload.py
@@ -37,17 +37,29 @@ np.random.seed(seed)
 paddle.seed(seed)
 
 
-def train_mlp(model, offload=False):
+def train_mlp(model, offload=False, test=False):
     optimizer = optimizer_setting(model=model, use_pure_fp16=True)
     model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
 
     scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
     scaler = GroupShardedScaler(scaler)
 
+    dp_group = (
+        None
+        if not test
+        else paddle.distributed.new_group(
+            list(range(paddle.distributed.get_world_size()))
+        )
+    )
     optimizer = GroupShardedOptimizerStage2(
-        params=optimizer._parameter_list, optim=optimizer, offload=offload
+        params=optimizer._parameter_list,
+        optim=optimizer,
+        offload=offload,
+        dp_group=dp_group,
+    )
+    model = GroupShardedStage2(
+        model, optimizer, buffer_max_size=2**21, dp_group=dp_group
     )
-    model = GroupShardedStage2(model, optimizer, buffer_max_size=2**21)
 
     paddle.seed(2023)
     np.random.seed(2023)
@@ -103,6 +115,13 @@ def test_sharding_stage2_offload():
         rtol=5e-3,
         atol=5e-3,
     )
+
+    # just to test assert error for the rate of coverage
+    try:
+        train_mlp(mlp_offload, offload=True, test=True)
+    except Exception as e:
+        assert isinstance(e, AssertionError)
+
     return
--
GitLab
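
Note (editor's addition, not part of the patch): after this change, constructing GroupShardedOptimizerStage2 with offload=True while a data-parallel group of more than one rank is passed fails fast with an AssertionError, and HybridParallelClipGrad is no longer installed when offload is enabled. The sketch below mirrors the new test path; the helper name check_offload_dp_rejected is hypothetical, and it assumes the script is launched on multiple ranks (e.g. via python -m paddle.distributed.launch) with the parallel environment already initialized.

```python
# Minimal sketch, not the upstream test: shows the behavior guarded by the
# assertion added in this patch. Assumes a multi-rank launch with the
# distributed environment already initialized; `check_offload_dp_rejected`
# and `base_optimizer` are illustrative names, not part of the patch.
import paddle
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
    GroupShardedOptimizerStage2,
)


def check_offload_dp_rejected(base_optimizer):
    # A group covering every rank gives dp_group.nranks > 1, which the
    # patched __init__ rejects when offload=True.
    dp_group = paddle.distributed.new_group(
        list(range(paddle.distributed.get_world_size()))
    )
    try:
        GroupShardedOptimizerStage2(
            params=base_optimizer._parameter_list,
            optim=base_optimizer,
            offload=True,
            dp_group=dp_group,
        )
    except AssertionError:
        return True  # expected after this patch: offload + data parallel is rejected
    return False
```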