提交 6125fa94 编写于 作者: H HydrogenSulfate

change Constant to ConstLR

上级 cc12db09
...@@ -87,7 +87,7 @@ Optimizer: ...@@ -87,7 +87,7 @@ Optimizer:
- SGD: - SGD:
scope: CenterLoss scope: CenterLoss
lr: lr:
name: Constant name: ConstLR
learning_rate: 1000.0 # NOTE: set to ori_lr*(1/centerloss_weight) to avoid manually scaling centers' gradients. learning_rate: 1000.0 # NOTE: set to ori_lr*(1/centerloss_weight) to avoid manually scaling centers' gradients.
# data loader for train and eval # data loader for train and eval
......
...@@ -93,7 +93,7 @@ class LRBase(object): ...@@ -93,7 +93,7 @@ class LRBase(object):
return warmup_lr return warmup_lr
class ConstantImpl(lr.LRScheduler): class Constant(lr.LRScheduler):
"""Constant learning rate Class implementation """Constant learning rate Class implementation
Args: Args:
...@@ -104,7 +104,7 @@ class ConstantImpl(lr.LRScheduler): ...@@ -104,7 +104,7 @@ class ConstantImpl(lr.LRScheduler):
def __init__(self, learning_rate, last_epoch=-1, **kwargs): def __init__(self, learning_rate, last_epoch=-1, **kwargs):
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.last_epoch = last_epoch self.last_epoch = last_epoch
super(ConstantImpl, self).__init__() super(Constant, self).__init__()
def get_lr(self) -> float: def get_lr(self) -> float:
"""always return the same learning rate """always return the same learning rate
...@@ -112,7 +112,7 @@ class ConstantImpl(lr.LRScheduler): ...@@ -112,7 +112,7 @@ class ConstantImpl(lr.LRScheduler):
return self.learning_rate return self.learning_rate
class Constant(LRBase): class ConstLR(LRBase):
"""Constant learning rate """Constant learning rate
Args: Args:
...@@ -134,12 +134,12 @@ class Constant(LRBase): ...@@ -134,12 +134,12 @@ class Constant(LRBase):
last_epoch=-1, last_epoch=-1,
by_epoch=False, by_epoch=False,
**kwargs): **kwargs):
super(Constant, self).__init__(epochs, step_each_epoch, learning_rate, super(ConstLR, self).__init__(epochs, step_each_epoch, learning_rate,
warmup_epoch, warmup_start_lr, warmup_epoch, warmup_start_lr,
last_epoch, by_epoch) last_epoch, by_epoch)
def __call__(self): def __call__(self):
learning_rate = ConstantImpl( learning_rate = Constant(
learning_rate=self.learning_rate, last_epoch=self.last_epoch) learning_rate=self.learning_rate, last_epoch=self.last_epoch)
if self.warmup_steps > 0: if self.warmup_steps > 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册