diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 230ad8ade034e764d3d6fd86056c5035eb44feb6..b1a572d4edfc30d9fdccc45b1b056ef7411cf44d 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -580,6 +580,13 @@ class GroupShardedOptimizerStage2(Optimizer):
 
         return __impl__
 
+    def set_lr(self, lr):
+        super().set_lr(lr)
+        self._optim.set_lr(lr)
+
+    def get_lr(self):
+        return self._optim.get_lr()
+
     @paddle.autograd.no_grad()
     def _broadcast_params_overlap_forward(self):
         # Exchange all the shards with the other ranks,
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
index 4d9f43db7fa6c3e94d7c7553baeb6fa0dafc8ef4..ef3bf9df182058df5f4cc92ad3aa7dd12ae96bb2 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
@@ -100,6 +100,10 @@ def train_mlp(
         dp_group=dp_group,
     )
 
+    # just for test_coverage.
+    if shard_level == "os_g":
+        optimizer.set_lr(optimizer.get_lr())
+
     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True
    )
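
For context, the first hunk adds learning-rate delegation to GroupShardedOptimizerStage2, which wraps an inner optimizer held in self._optim: set_lr must update both the wrapper's own state and the wrapped optimizer, and get_lr reads back from the wrapped optimizer. Below is a minimal, self-contained sketch of that delegation pattern under assumed names (BaseOptimizer and ShardedWrapper are hypothetical stand-ins, not Paddle classes):

# Sketch of the wrapper/inner-optimizer delegation pattern from the patch.
# BaseOptimizer and ShardedWrapper are illustrative names, not Paddle APIs.

class BaseOptimizer:
    def __init__(self, lr):
        self._lr = lr

    def set_lr(self, lr):
        self._lr = lr

    def get_lr(self):
        return self._lr


class ShardedWrapper(BaseOptimizer):
    """Wraps an inner optimizer, mirroring the stage-2 delegation."""

    def __init__(self, inner):
        super().__init__(inner.get_lr())
        self._optim = inner

    def set_lr(self, lr):
        # Keep the wrapper's own state and the wrapped optimizer in sync;
        # updating only one would let the two learning rates diverge.
        super().set_lr(lr)
        self._optim.set_lr(lr)

    def get_lr(self):
        # Read from the wrapped optimizer, the single source of truth.
        return self._optim.get_lr()


if __name__ == "__main__":
    inner = BaseOptimizer(lr=0.1)
    wrapper = ShardedWrapper(inner)
    wrapper.set_lr(0.05)
    assert inner.get_lr() == wrapper.get_lr() == 0.05

The second hunk exercises exactly this round trip in the "os_g" test path: optimizer.set_lr(optimizer.get_lr()) is a no-op on the value, added only so the new methods are covered by the existing test.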