diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 9a25d7c4912bacc49c727c09958c1daaaf5c7c0c..f5ca60b1003ea0021401e66459f59304c8f9f0a1 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -149,6 +149,11 @@ class GroupShardedOptimizerStage2(Optimizer):
         self._rank = self._group.rank
         self._global_root_rank = self._group.ranks[0]
 
+        if self._dp_group is not None and self._dp_group.nranks > 1:
+            assert (
+                not offload
+            ), "Offload is not supported when sharding stage2 is combined with data parallel. Please use pure sharding stage2 without an extra data parallel group."
+
         # Synchronous all ranks models
         if pertrain_sync_models:
             self._sync_params_and_buffers()
@@ -164,6 +169,7 @@ class GroupShardedOptimizerStage2(Optimizer):
             if (
                 hcg
                 and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+                and not offload
             ):
                 self._optim._grad_clip = HybridParallelClipGrad(
                     self._optim._grad_clip, hcg
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_offload.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_offload.py
index 8b6b9241bdfae24dddc7db9c3d929643592be379..e868b4ff3404da6a84d80634c202e49f66c9c2da 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_offload.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_offload.py
@@ -37,17 +37,29 @@
 np.random.seed(seed)
 paddle.seed(seed)
 
 
-def train_mlp(model, offload=False):
+def train_mlp(model, offload=False, test=False):
     optimizer = optimizer_setting(model=model, use_pure_fp16=True)
     model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
     scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
     scaler = GroupShardedScaler(scaler)
+    dp_group = (
+        None
+        if not test
+        else paddle.distributed.new_group(
+            list(range(paddle.distributed.get_world_size()))
+        )
+    )
 
     optimizer = GroupShardedOptimizerStage2(
-        params=optimizer._parameter_list, optim=optimizer, offload=offload
+        params=optimizer._parameter_list,
+        optim=optimizer,
+        offload=offload,
+        dp_group=dp_group,
+    )
+    model = GroupShardedStage2(
+        model, optimizer, buffer_max_size=2**21, dp_group=dp_group
     )
-    model = GroupShardedStage2(model, optimizer, buffer_max_size=2**21)
 
     paddle.seed(2023)
     np.random.seed(2023)
@@ -103,6 +115,13 @@ def test_sharding_stage2_offload():
             rtol=5e-3,
             atol=5e-3,
         )
+
+    # just to exercise the assertion error path for test coverage
+    try:
+        train_mlp(mlp_offload, offload=True, test=True)
+    except Exception as e:
+        assert isinstance(e, AssertionError)
+
     return
 
 
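
Below the patch, a minimal sketch (not part of the diff) of the behavior the new guard enforces: constructing GroupShardedOptimizerStage2 with offload=True while also passing a multi-rank dp_group should trip the assertion added above. The Linear model, the AdamW inner optimizer, and the check_offload_guard helper are illustrative assumptions, and a multi-GPU launch (e.g. via python -m paddle.distributed.launch) is presumed so that dp_group.nranks > 1.

```python
# Hypothetical reproduction sketch -- not part of the patch above.
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
    GroupShardedOptimizerStage2,
)


def check_offload_guard():
    dist.init_parallel_env()

    # Illustrative model/optimizer; any trainable Layer would do.
    model = paddle.nn.Linear(16, 16)
    inner_opt = paddle.optimizer.AdamW(parameters=model.parameters())

    # A data-parallel group spanning all ranks, mirroring the updated test.
    dp_group = dist.new_group(list(range(dist.get_world_size())))

    try:
        GroupShardedOptimizerStage2(
            params=model.parameters(),
            optim=inner_opt,
            offload=True,       # offload together with ...
            dp_group=dp_group,  # ... an extra data-parallel group
        )
    except AssertionError:
        # Expected with this patch: offload + data parallel is rejected early.
        print("offload + dp_group rejected as expected")


if __name__ == "__main__":
    check_offload_guard()
```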