提交 cc12db09 编写于 作者: H HydrogenSulfate

fix Constant learning rate bug

上级 9e30739c
...@@ -93,6 +93,25 @@ class LRBase(object): ...@@ -93,6 +93,25 @@ class LRBase(object):
return warmup_lr return warmup_lr
class ConstantImpl(lr.LRScheduler):
"""Constant learning rate Class implementation
Args:
learning_rate (float): The initial learning rate
last_epoch (int, optional): The index of last epoch. Default: -1.
"""
def __init__(self, learning_rate, last_epoch=-1, **kwargs):
self.learning_rate = learning_rate
self.last_epoch = last_epoch
super(ConstantImpl, self).__init__()
def get_lr(self) -> float:
"""always return the same learning rate
"""
return self.learning_rate
class Constant(LRBase): class Constant(LRBase):
"""Constant learning rate """Constant learning rate
...@@ -120,17 +139,9 @@ class Constant(LRBase): ...@@ -120,17 +139,9 @@ class Constant(LRBase):
last_epoch, by_epoch) last_epoch, by_epoch)
def __call__(self): def __call__(self):
learning_rate = lr.LRScheduler( learning_rate = ConstantImpl(
learning_rate=self.learning_rate, last_epoch=self.last_epoch) learning_rate=self.learning_rate, last_epoch=self.last_epoch)
def make_get_lr():
def get_lr(self):
return self.learning_rate
return get_lr
setattr(learning_rate, "get_lr", make_get_lr())
if self.warmup_steps > 0: if self.warmup_steps > 0:
learning_rate = self.linear_warmup(learning_rate) learning_rate = self.linear_warmup(learning_rate)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册