未验证 提交 934121ad 编写于 作者: H HydrogenSulfate 提交者: GitHub

Merge pull request #2322 from HydrogenSulfate/fix_constlr

fix Constant learning rate bug
......@@ -87,7 +87,7 @@ Optimizer:
- SGD:
scope: CenterLoss
lr:
name: Constant
name: ConstLR
learning_rate: 1000.0 # NOTE: set to ori_lr*(1/centerloss_weight) to avoid manually scaling centers' gradients.
# data loader for train and eval
......
......@@ -93,7 +93,26 @@ class LRBase(object):
return warmup_lr
class Constant(LRBase):
class Constant(lr.LRScheduler):
    """Constant learning rate scheduler.

    ``get_lr`` always returns the same value, so the learning rate never
    changes regardless of epoch/step.

    Args:
        learning_rate (float): The constant learning rate value.
        last_epoch (int, optional): The index of last epoch. Default: -1.
    """

    def __init__(self, learning_rate, last_epoch=-1, **kwargs):
        self.learning_rate = learning_rate
        self.last_epoch = last_epoch
        # Forward both values to the base scheduler. Calling it with no
        # arguments initializes the base class with its own defaults,
        # which silently discards the caller's ``last_epoch``
        # (NOTE(review): the base ``__init__`` re-assigns ``last_epoch``
        # — confirm against the paddle LRScheduler implementation).
        super(Constant, self).__init__(
            learning_rate=learning_rate, last_epoch=last_epoch)

    def get_lr(self) -> float:
        """Return the constant learning rate."""
        return self.learning_rate
class ConstLR(LRBase):
"""Constant learning rate
Args:
......@@ -115,22 +134,14 @@ class Constant(LRBase):
last_epoch=-1,
by_epoch=False,
**kwargs):
super(Constant, self).__init__(epochs, step_each_epoch, learning_rate,
warmup_epoch, warmup_start_lr,
last_epoch, by_epoch)
super(ConstLR, self).__init__(epochs, step_each_epoch, learning_rate,
warmup_epoch, warmup_start_lr,
last_epoch, by_epoch)
def __call__(self):
learning_rate = lr.LRScheduler(
learning_rate = Constant(
learning_rate=self.learning_rate, last_epoch=self.last_epoch)
def make_get_lr():
def get_lr(self):
return self.learning_rate
return get_lr
setattr(learning_rate, "get_lr", make_get_lr())
if self.warmup_steps > 0:
learning_rate = self.linear_warmup(learning_rate)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册