From 8f1e24d5297102ac9ca68405406059d0caeb10c6 Mon Sep 17 00:00:00 2001
From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com>
Date: Fri, 9 Dec 2022 10:44:50 +0800
Subject: [PATCH] add set_lr & get_lr for stage2 optimizer. (#48857)

---
 .../sharding/group_sharded_optimizer_stage2.py          | 7 +++++++
 .../collective/fleet/dygraph_group_sharded_api_eager.py | 4 ++++
 2 files changed, 11 insertions(+)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 230ad8ade0..b1a572d4ed 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -580,6 +580,13 @@ class GroupShardedOptimizerStage2(Optimizer):
 
         return __impl__
 
+    def set_lr(self, lr):
+        super().set_lr(lr)
+        self._optim.set_lr(lr)
+
+    def get_lr(self):
+        return self._optim.get_lr()
+
     @paddle.autograd.no_grad()
     def _broadcast_params_overlap_forward(self):
         # Exchange all the shards with the other ranks,
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
index 4d9f43db7f..ef3bf9df18 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
@@ -100,6 +100,10 @@ def train_mlp(
         dp_group=dp_group,
     )
 
+    # just for test_coverage.
+    if shard_level == "os_g":
+        optimizer.set_lr(optimizer.get_lr())
+
     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True
     )
--
GitLab
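
Note for reviewers: below is a minimal usage sketch (not part of the patch) of what
this change enables. It assumes a collective environment (e.g. launched via
python -m paddle.distributed.launch) and a hypothetical toy model/optimizer; only
the level="os_g" wrapping and the set_lr/get_lr calls reflect the patched code path.

    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.sharding import group_sharded_parallel

    fleet.init(is_collective=True)

    # Hypothetical toy model and inner optimizer for illustration.
    model = paddle.nn.Linear(16, 16)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=0.001, parameters=model.parameters()
    )

    # level="os_g" shards optimizer state and gradients (stage 2); the returned
    # optimizer is a GroupShardedOptimizerStage2 wrapping the inner AdamW.
    model, optimizer, scaler = group_sharded_parallel(
        model, optimizer, level="os_g"
    )

    # With this patch, set_lr is forwarded to the wrapped inner optimizer as
    # well, so both learning rates stay in sync; get_lr reads it back.
    optimizer.set_lr(0.0005)
    assert optimizer.get_lr() == 0.0005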