Unverified · Commit dbc3fd5e authored by H hong, committed by GitHub

fix LinearLrWarmup bug; test=develop (#24913)

Parent f6f7df9c
...@@ -664,6 +664,7 @@ class LinearLrWarmup(LearningRateDecay):
                 format(learning_rate))
         self.learning_rate = learning_rate
         self.warmup_steps = warmup_steps
+        self.start_lr = start_lr
         assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
             end_lr, start_lr)
         self.lr_ratio_before_warmup = (
...@@ -676,7 +677,7 @@ class LinearLrWarmup(LearningRateDecay):
         from .. import layers
         if self.step_num < self.warmup_steps:
-            return self.lr_ratio_before_warmup * self.step_num
+            return self.lr_ratio_before_warmup * self.step_num + self.start_lr
         else:
             return base_lr
......
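
For reference, a minimal sketch of the arithmetic behind the fix, using hypothetical values (not taken from the patch) and assuming, as the truncated first hunk suggests, that lr_ratio_before_warmup is (end_lr - start_lr) / warmup_steps. Before this commit the warmup branch dropped start_lr, so the schedule ramped up from 0 instead of from start_lr.

    # Hypothetical values for illustration only.
    start_lr, end_lr, warmup_steps = 0.1, 0.5, 4

    # Assumed definition, matching the truncated
    # "self.lr_ratio_before_warmup = (" line in the first hunk.
    lr_ratio_before_warmup = (end_lr - start_lr) / warmup_steps  # 0.1 per step

    step_num = 2
    old_lr = lr_ratio_before_warmup * step_num             # 0.2 -- ramps from 0, ignores start_lr
    new_lr = lr_ratio_before_warmup * step_num + start_lr  # 0.3 -- ramps from start_lr toward end_lr

    print(old_lr, new_lr)

With the added start_lr term, the warmup steps interpolate from start_lr toward end_lr as intended, instead of starting at 0.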