diff --git a/configs/vqa/re/layoutlmv2.yml b/configs/vqa/re/layoutlmv2.yml
index 9daa2a968e04e9e30a9938eb5c7e57c4bd84e019..b213212f15c61a4f32194f0090c92eb595681068 100644
--- a/configs/vqa/re/layoutlmv2.yml
+++ b/configs/vqa/re/layoutlmv2.yml
@@ -34,10 +34,12 @@ Optimizer:
   beta2: 0.999
   clip_norm: 10
   lr:
-    name: Piecewise
-    values: [0.000005, 0.00005]
-    decay_epochs: [10]
-    warmup_epoch: 0
+    # name: Piecewise
+    # values: [0.000005, 0.00005]
+    # decay_epochs: [10]
+    # warmup_epoch: 0
+    learning_rate: 0.00005
+    warmup_epoch: 10
   regularizer:
     name: L2
     factor: 0.00000
diff --git a/configs/vqa/re/layoutxlm.yml b/configs/vqa/re/layoutxlm.yml
index d413b17494d969eceb07711d2c6679c00cb07f7f..ff16120ac1be92e989ebfda6af3ccf346dde89cd 100644
--- a/configs/vqa/re/layoutxlm.yml
+++ b/configs/vqa/re/layoutxlm.yml
@@ -34,10 +34,8 @@ Optimizer:
   beta2: 0.999
   clip_norm: 10
   lr:
-    name: Piecewise
-    values: [0.000005, 0.00005]
-    decay_epochs: [10]
-    warmup_epoch: 0
+    learning_rate: 0.00005
+    warmup_epoch: 10
   regularizer:
     name: L2
     factor: 0.00000
diff --git a/ppocr/optimizer/__init__.py b/ppocr/optimizer/__init__.py
index e0c6b90371cb4b09fb894ceeaeb8595e51c6c557..4110fb47678583cff826a9bc855b3fb378a533f9 100644
--- a/ppocr/optimizer/__init__.py
+++ b/ppocr/optimizer/__init__.py
@@ -25,11 +25,8 @@ __all__ = ['build_optimizer']
 def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     from . import learning_rate
     lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
-    if 'name' in lr_config:
-        lr_name = lr_config.pop('name')
-        lr = getattr(learning_rate, lr_name)(**lr_config)()
-    else:
-        lr = lr_config['learning_rate']
+    lr_name = lr_config.pop('name', 'Const')
+    lr = getattr(learning_rate, lr_name)(**lr_config)()
     return lr
 
 
diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py
index b1879f3ee509761043c1797d8b67e4e0988af130..fe251f36e736bb1eac8a71a8115c941cbd7443e6 100644
--- a/ppocr/optimizer/learning_rate.py
+++ b/ppocr/optimizer/learning_rate.py
@@ -275,4 +275,36 @@ class OneCycle(object):
                 start_lr=0.0,
                 end_lr=self.max_lr,
                 last_epoch=self.last_epoch)
-        return learning_rate
\ No newline at end of file
+        return learning_rate
+
+
+class Const(object):
+    """
+    Const learning rate decay
+    Args:
+        learning_rate(float): initial learning rate
+        step_each_epoch(int): steps each epoch
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+    """
+
+    def __init__(self,
+                 learning_rate,
+                 step_each_epoch,
+                 warmup_epoch=0,
+                 last_epoch=-1,
+                 **kwargs):
+        super(Const, self).__init__()
+        self.learning_rate = learning_rate
+        self.last_epoch = last_epoch
+        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+    def __call__(self):
+        learning_rate = self.learning_rate
+        if self.warmup_epoch > 0:
+            learning_rate = lr.LinearWarmup(
+                learning_rate=learning_rate,
+                warmup_steps=self.warmup_epoch,
+                start_lr=0.0,
+                end_lr=self.learning_rate,
+                last_epoch=self.last_epoch)
+        return learning_rate
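
A minimal usage sketch of the new fallback path, assuming PaddleOCR with this patch and PaddlePaddle are installed; the epochs and step_each_epoch values below are placeholders, not values from the patch.

from ppocr.optimizer import build_lr_scheduler

# Mirrors the updated Optimizer.lr section of layoutxlm.yml (no 'name' key).
lr_config = {'learning_rate': 0.00005, 'warmup_epoch': 10}

# With 'name' absent, build_lr_scheduler now falls back to the new Const class;
# since warmup_epoch > 0, Const wraps the constant rate in paddle.optimizer.lr.LinearWarmup
# instead of returning a bare float as the old else-branch did.
lr_scheduler = build_lr_scheduler(lr_config, epochs=200, step_each_epoch=100)
print(type(lr_scheduler))  # <class 'paddle.optimizer.lr.LinearWarmup'>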