Unverified commit 8f1e24d5, authored by wuhuachaocoding and committed by GitHub

add set_lr & get_lr for stage2 optimizer. (#48857)

Parent 39ffef0d
...@@ -580,6 +580,13 @@ class GroupShardedOptimizerStage2(Optimizer):
        return __impl__

    def set_lr(self, lr):
        super().set_lr(lr)
        self._optim.set_lr(lr)

    def get_lr(self):
        return self._optim.get_lr()

    @paddle.autograd.no_grad()
    def _broadcast_params_overlap_forward(self):
        # Exchange all the shards with the other ranks,
......
...@@ -100,6 +100,10 @@ def train_mlp(
        dp_group=dp_group,
    )

    # just for test_coverage.
    if shard_level == "os_g":
        optimizer.set_lr(optimizer.get_lr())

    train_reader = paddle.batch(
        reader_decorator(), batch_size=batch_size, drop_last=True
    )
......
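
Below is a minimal, illustrative sketch (not part of the commit) of how the newly exposed set_lr/get_lr could be exercised on a stage-2 sharded optimizer. The network shape, learning-rate values, and the "os_g" sharding level are assumptions chosen to mirror the test above; the script would need to run under paddle.distributed.launch with the collective environment available.

# Hypothetical usage sketch: adjust the learning rate of a stage-2
# sharded optimizer through the new set_lr/get_lr hooks.
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

fleet.init(is_collective=True)

model = paddle.nn.Linear(1024, 1024)  # placeholder model (assumption)
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters()
)

# Shard optimizer states and gradients across ranks (level "os_g").
model, optimizer, _ = group_sharded_parallel(
    model=model, optimizer=optimizer, level="os_g"
)

# New in this commit: the stage-2 wrapper forwards learning-rate reads
# and updates to the inner optimizer, so manual LR adjustment works.
current_lr = optimizer.get_lr()
optimizer.set_lr(current_lr * 0.1)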