Unverified commit 8f1e24d5, authored by W wuhuachaocoding, committed by GitHub

add set_lr & get_lr for stage2 optimizer. (#48857)

Parent 39ffef0d
@@ -580,6 +580,13 @@ class GroupShardedOptimizerStage2(Optimizer):
         return __impl__
 
+    def set_lr(self, lr):
+        super().set_lr(lr)
+        self._optim.set_lr(lr)
+
+    def get_lr(self):
+        return self._optim.get_lr()
+
     @paddle.autograd.no_grad()
     def _broadcast_params_overlap_forward(self):
         # Exchange all the shards with the other ranks,
......
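For reference, a minimal usage sketch (not part of this diff) of the two new methods. The model, sizes, and launch setup are illustrative assumptions; group_sharded_parallel with level="os_g" is the public entry point that wraps the inner optimizer in GroupShardedOptimizerStage2, and only set_lr/get_lr are what this commit adds.

# Hypothetical sketch, assuming the script is launched via
# `python -m paddle.distributed.launch` on multiple GPUs.
import paddle
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters()
)

# level="os_g" returns a GroupShardedOptimizerStage2 wrapper around AdamW.
model, optimizer, scaler = group_sharded_parallel(
    model=model, optimizer=optimizer, level="os_g"
)

# set_lr forwards the rate to both the wrapper and the wrapped inner
# optimizer, so get_lr (which reads the inner optimizer) stays consistent.
optimizer.set_lr(5e-4)
assert abs(optimizer.get_lr() - 5e-4) < 1e-12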
@@ -100,6 +100,10 @@ def train_mlp(
             dp_group=dp_group,
         )
 
+    # just for test_coverage.
+    if shard_level == "os_g":
+        optimizer.set_lr(optimizer.get_lr())
+
     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True
     )
......