Commit 8a28962c authored by 文幕地方

add Const lr

Parent 07633eb8
@@ -34,10 +34,12 @@ Optimizer:
   beta2: 0.999
   clip_norm: 10
   lr:
-    name: Piecewise
-    values: [0.000005, 0.00005]
-    decay_epochs: [10]
-    warmup_epoch: 0
+    # name: Piecewise
+    # values: [0.000005, 0.00005]
+    # decay_epochs: [10]
+    # warmup_epoch: 0
+    learning_rate: 0.00005
+    warmup_epoch: 10
   regularizer:
     name: L2
     factor: 0.00000
@@ -34,10 +34,8 @@ Optimizer:
   beta2: 0.999
   clip_norm: 10
   lr:
-    name: Piecewise
-    values: [0.000005, 0.00005]
-    decay_epochs: [10]
-    warmup_epoch: 0
+    learning_rate: 0.00005
+    warmup_epoch: 10
   regularizer:
     name: L2
     factor: 0.00000
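Both configuration hunks make the same change: the Piecewise schedule is dropped in favour of a constant base learning rate of 0.00005 with a 10-epoch warmup. As a minimal sketch (not part of the commit), this is what the parsed lr section looks like before it reaches build_lr_scheduler, assuming PyYAML is used for loading:

import yaml  # PyYAML, assumed available

# The updated lr section from the hunks above, as plain YAML text.
lr_yaml = """
lr:
  learning_rate: 0.00005
  warmup_epoch: 10
"""

lr_config = yaml.safe_load(lr_yaml)['lr']
print(lr_config)  # {'learning_rate': 5e-05, 'warmup_epoch': 10}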
@@ -25,11 +25,8 @@ __all__ = ['build_optimizer']
 def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     from . import learning_rate
     lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
-    if 'name' in lr_config:
-        lr_name = lr_config.pop('name')
-        lr = getattr(learning_rate, lr_name)(**lr_config)()
-    else:
-        lr = lr_config['learning_rate']
+    lr_name = lr_config.pop('name', 'Const')
+    lr = getattr(learning_rate, lr_name)(**lr_config)()
     return lr
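With the 'name' key now optional and defaulting to 'Const', an lr section that only sets learning_rate and warmup_epoch is dispatched through getattr to the new Const scheduler, instead of falling back to a bare float. A minimal sketch of that dispatch, assuming the builder lives at ppocr.optimizer as in PaddleOCR and using hypothetical training settings (epochs=400, step_each_epoch=100):

from ppocr.optimizer import build_lr_scheduler  # assumes a PaddleOCR checkout on PYTHONPATH

# No 'name' key: build_lr_scheduler pops it with a default of 'Const'.
lr_config = {'learning_rate': 0.00005, 'warmup_epoch': 10}

# warmup_epoch > 0, so Const.__call__ returns a paddle LinearWarmup scheduler.
scheduler = build_lr_scheduler(lr_config, epochs=400, step_each_epoch=100)
print(type(scheduler).__name__)  # LinearWarmup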
@@ -275,4 +275,36 @@ class OneCycle(object):
             start_lr=0.0,
             end_lr=self.max_lr,
             last_epoch=self.last_epoch)
-        return learning_rate
\ No newline at end of file
+        return learning_rate
+
+
+class Const(object):
+    """
+    Const learning rate decay
+    Args:
+        learning_rate(float): initial learning rate
+        step_each_epoch(int): steps each epoch
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+    """
+
+    def __init__(self,
+                 learning_rate,
+                 step_each_epoch,
+                 warmup_epoch=0,
+                 last_epoch=-1,
+                 **kwargs):
+        super(Const, self).__init__()
+        self.learning_rate = learning_rate
+        self.last_epoch = last_epoch
+        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+    def __call__(self):
+        learning_rate = self.learning_rate
+        if self.warmup_epoch > 0:
+            learning_rate = lr.LinearWarmup(
+                learning_rate=learning_rate,
+                warmup_steps=self.warmup_epoch,
+                start_lr=0.0,
+                end_lr=self.learning_rate,
+                last_epoch=self.last_epoch)
+        return learning_rate
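A quick sanity check of the new class under its two branches, assuming paddlepaddle is installed and Const is imported from the module patched above; the values mirror the updated configs:

from ppocr.optimizer.learning_rate import Const  # requires a PaddleOCR checkout

# warmup_epoch defaults to 0, so __call__ simply returns the constant float.
assert Const(learning_rate=0.00005, step_each_epoch=100)() == 0.00005

# With warmup, the constant rate is wrapped in paddle.optimizer.lr.LinearWarmup,
# ramping from 0.0 to 0.00005 over 10 * 100 = 1000 steps.
sched = Const(learning_rate=0.00005, step_each_epoch=100, warmup_epoch=10)()
print(type(sched).__name__)  # LinearWarmup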