From de3e2e7cd3b8b65ee02d7a41e570fa5b511a3c1d Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Tue, 29 Dec 2020 13:49:43 +0800 Subject: [PATCH] add CyclicalCosineDecay (#1599) --- ppocr/optimizer/learning_rate.py | 51 +++++++++++++++++++++++++++++--- ppocr/optimizer/lr_scheduler.py | 49 ++++++++++++++++++++++++++++++ tools/program.py | 4 +-- 3 files changed, 98 insertions(+), 6 deletions(-) create mode 100644 ppocr/optimizer/lr_scheduler.py diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py index 8f303e83..e1b10992 100644 --- a/ppocr/optimizer/learning_rate.py +++ b/ppocr/optimizer/learning_rate.py @@ -18,6 +18,7 @@ from __future__ import print_function from __future__ import unicode_literals from paddle.optimizer import lr +from .lr_scheduler import CyclicalCosineDecay class Linear(object): @@ -46,7 +47,7 @@ class Linear(object): self.end_lr = end_lr self.power = power self.last_epoch = last_epoch - self.warmup_epoch = warmup_epoch * step_each_epoch + self.warmup_epoch = round(warmup_epoch * step_each_epoch) def __call__(self): learning_rate = lr.PolynomialDecay( @@ -87,7 +88,7 @@ class Cosine(object): self.learning_rate = learning_rate self.T_max = step_each_epoch * epochs self.last_epoch = last_epoch - self.warmup_epoch = warmup_epoch * step_each_epoch + self.warmup_epoch = round(warmup_epoch * step_each_epoch) def __call__(self): learning_rate = lr.CosineAnnealingDecay( @@ -129,7 +130,7 @@ class Step(object): self.learning_rate = learning_rate self.gamma = gamma self.last_epoch = last_epoch - self.warmup_epoch = warmup_epoch * step_each_epoch + self.warmup_epoch = round(warmup_epoch * step_each_epoch) def __call__(self): learning_rate = lr.StepDecay( @@ -168,7 +169,7 @@ class Piecewise(object): self.boundaries = [step_each_epoch * e for e in decay_epochs] self.values = values self.last_epoch = last_epoch - self.warmup_epoch = warmup_epoch * step_each_epoch + self.warmup_epoch = round(warmup_epoch * step_each_epoch) def __call__(self): learning_rate = lr.PiecewiseDecay( @@ -183,3 +184,45 @@ class Piecewise(object): end_lr=self.values[0], last_epoch=self.last_epoch) return learning_rate + + +class CyclicalCosine(object): + """ + Cyclical cosine learning rate decay + Args: + learning_rate(float): initial learning rate + step_each_epoch(int): steps each epoch + epochs(int): total training epochs + cycle(int): period of the cosine learning rate + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + """ + + def __init__(self, + learning_rate, + step_each_epoch, + epochs, + cycle, + warmup_epoch=0, + last_epoch=-1, + **kwargs): + super(CyclicalCosine, self).__init__() + self.learning_rate = learning_rate + self.T_max = step_each_epoch * epochs + self.last_epoch = last_epoch + self.warmup_epoch = round(warmup_epoch * step_each_epoch) + self.cycle = round(cycle * step_each_epoch) + + def __call__(self): + learning_rate = CyclicalCosineDecay( + learning_rate=self.learning_rate, + T_max=self.T_max, + cycle=self.cycle, + last_epoch=self.last_epoch) + if self.warmup_epoch > 0: + learning_rate = lr.LinearWarmup( + learning_rate=learning_rate, + warmup_steps=self.warmup_epoch, + start_lr=0.0, + end_lr=self.learning_rate, + last_epoch=self.last_epoch) + return learning_rate diff --git a/ppocr/optimizer/lr_scheduler.py b/ppocr/optimizer/lr_scheduler.py new file mode 100644 index 00000000..21aec737 --- /dev/null +++ b/ppocr/optimizer/lr_scheduler.py @@ -0,0 +1,49 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from paddle.optimizer.lr import LRScheduler + + +class CyclicalCosineDecay(LRScheduler): + def __init__(self, + learning_rate, + T_max, + cycle=1, + last_epoch=-1, + eta_min=0.0, + verbose=False): + """ + Cyclical cosine learning rate decay + A learning rate which can be referred in https://arxiv.org/pdf/2012.12645.pdf + Args: + learning rate(float): learning rate + T_max(int): maximum epoch num + cycle(int): period of the cosine decay + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + eta_min(float): minimum learning rate during training + verbose(bool): whether to print learning rate for each epoch + """ + super(CyclicalCosineDecay, self).__init__(learning_rate, last_epoch, + verbose) + self.cycle = cycle + self.eta_min = eta_min + + def get_lr(self): + if self.last_epoch == 0: + return self.base_lr + reletive_epoch = self.last_epoch % self.cycle + lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \ + (1 + math.cos(math.pi * reletive_epoch / self.cycle)) + return lr diff --git a/tools/program.py b/tools/program.py index c712fe14..c2915426 100755 --- a/tools/program.py +++ b/tools/program.py @@ -179,9 +179,9 @@ def train(config, if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: - start_epoch = 0 + start_epoch = 1 - for epoch in range(start_epoch, epoch_num): + for epoch in range(start_epoch, epoch_num + 1): if epoch > 0: train_dataloader = build_dataloader(config, 'Train', device, logger) train_batch_cost = 0.0 -- GitLab