From 107a316f85205ab7a6e9e0b78244096611278fe1 Mon Sep 17 00:00:00 2001
From: bupt906 <41312357+bupt906@users.noreply.github.com>
Date: Sun, 16 Jan 2022 23:53:39 +0800
Subject: [PATCH] add onecycle (#5252)

---
 ppocr/optimizer/learning_rate.py |  52 +++++++++++++-
 ppocr/optimizer/lr_scheduler.py  | 113 +++++++++++++++++++++++++++++++
 2 files changed, 164 insertions(+), 1 deletion(-)

diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py
index e1b10992..b1879f3e 100644
--- a/ppocr/optimizer/learning_rate.py
+++ b/ppocr/optimizer/learning_rate.py
@@ -18,7 +18,7 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from paddle.optimizer import lr
-from .lr_scheduler import CyclicalCosineDecay
+from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay
 
 
 class Linear(object):
@@ -226,3 +226,53 @@ class CyclicalCosine(object):
                 end_lr=self.learning_rate,
                 last_epoch=self.last_epoch)
         return learning_rate
+
+
+class OneCycle(object):
+    """
+    One Cycle learning rate decay
+    Args:
+        max_lr(float): Upper learning rate boundary
+        epochs(int): total training epochs
+        step_each_epoch(int): number of steps in each epoch
+        anneal_strategy(str): {'cos', 'linear'} Specifies the annealing strategy: "cos" for cosine annealing, "linear" for linear annealing.
+            Default: 'cos'
+        three_phase(bool): If True, use a third phase of the schedule to annihilate the learning rate according to 'final_div_factor'
+            instead of modifying the second phase (the first two phases will be symmetrical about the step indicated by 'pct_start').
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, meaning the initial learning rate.
+    """
+
+    def __init__(self,
+                 max_lr,
+                 epochs,
+                 step_each_epoch,
+                 anneal_strategy='cos',
+                 three_phase=False,
+                 warmup_epoch=0,
+                 last_epoch=-1,
+                 **kwargs):
+        super(OneCycle, self).__init__()
+        self.max_lr = max_lr
+        self.epochs = epochs
+        self.steps_per_epoch = step_each_epoch
+        self.anneal_strategy = anneal_strategy
+        self.three_phase = three_phase
+        self.last_epoch = last_epoch
+        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+    def __call__(self):
+        learning_rate = OneCycleDecay(
+            max_lr=self.max_lr,
+            epochs=self.epochs,
+            steps_per_epoch=self.steps_per_epoch,
+            anneal_strategy=self.anneal_strategy,
+            three_phase=self.three_phase,
+            last_epoch=self.last_epoch)
+        if self.warmup_epoch > 0:
+            learning_rate = lr.LinearWarmup(
+                learning_rate=learning_rate,
+                warmup_steps=self.warmup_epoch,
+                start_lr=0.0,
+                end_lr=self.max_lr,
+                last_epoch=self.last_epoch)
+        return learning_rate
\ No newline at end of file
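For reference, a minimal standalone sketch (not part of the patch) of how the OneCycle wrapper added above could be exercised once the change is applied. In normal training the class would presumably be selected by name via the Optimizer.lr section of a config, as with the existing schedulers in this file; the direct usage, toy model, and hyperparameter values below are purely illustrative and assume PaddlePaddle is installed with the repository root on PYTHONPATH.

    # Illustrative sketch only: build the scheduler through the new OneCycle
    # wrapper and attach it to a toy paddle optimizer. All values are assumptions.
    import paddle
    from ppocr.optimizer.learning_rate import OneCycle

    lr_builder = OneCycle(
        max_lr=0.001,         # peak learning rate of the cycle
        epochs=10,            # total training epochs
        step_each_epoch=100,  # iterations per epoch
        warmup_epoch=1)       # optional linear warm-up, converted to steps internally
    scheduler = lr_builder()  # __call__ returns a paddle LRScheduler (here wrapped in lr.LinearWarmup)

    model = paddle.nn.Linear(10, 2)
    optimizer = paddle.optimizer.Adam(
        learning_rate=scheduler, parameters=model.parameters())

    for step in range(10 * 100):
        x = paddle.randn([4, 10])
        loss = model(x).mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        scheduler.step()  # the schedule advances once per iteration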
diff --git a/ppocr/optimizer/lr_scheduler.py b/ppocr/optimizer/lr_scheduler.py
index 21aec737..f62f1f3b 100644
--- a/ppocr/optimizer/lr_scheduler.py
+++ b/ppocr/optimizer/lr_scheduler.py
@@ -47,3 +47,116 @@ class CyclicalCosineDecay(LRScheduler):
         lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \
             (1 + math.cos(math.pi * reletive_epoch / self.cycle))
         return lr
+
+
+class OneCycleDecay(LRScheduler):
+    """
+    One Cycle learning rate decay
+    The schedule is described in https://arxiv.org/abs/1708.07120
+    Code adapted from https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
+    """
+
+    def __init__(self,
+                 max_lr,
+                 epochs=None,
+                 steps_per_epoch=None,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 div_factor=25.,
+                 final_div_factor=1e4,
+                 three_phase=False,
+                 last_epoch=-1,
+                 verbose=False):
+
+        # Validate total_steps
+        if epochs <= 0 or not isinstance(epochs, int):
+            raise ValueError(
+                "Expected positive integer epochs, but got {}".format(epochs))
+        if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
+            raise ValueError(
+                "Expected positive integer steps_per_epoch, but got {}".format(
+                    steps_per_epoch))
+        self.total_steps = epochs * steps_per_epoch
+
+        self.max_lr = max_lr
+        self.initial_lr = self.max_lr / div_factor
+        self.min_lr = self.initial_lr / final_div_factor
+
+        if three_phase:
+            self._schedule_phases = [
+                {
+                    'end_step': float(pct_start * self.total_steps) - 1,
+                    'start_lr': self.initial_lr,
+                    'end_lr': self.max_lr,
+                },
+                {
+                    'end_step': float(2 * pct_start * self.total_steps) - 2,
+                    'start_lr': self.max_lr,
+                    'end_lr': self.initial_lr,
+                },
+                {
+                    'end_step': self.total_steps - 1,
+                    'start_lr': self.initial_lr,
+                    'end_lr': self.min_lr,
+                },
+            ]
+        else:
+            self._schedule_phases = [
+                {
+                    'end_step': float(pct_start * self.total_steps) - 1,
+                    'start_lr': self.initial_lr,
+                    'end_lr': self.max_lr,
+                },
+                {
+                    'end_step': self.total_steps - 1,
+                    'start_lr': self.max_lr,
+                    'end_lr': self.min_lr,
+                },
+            ]
+
+        # Validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError(
+                "Expected float between 0 and 1 for pct_start, but got {}".format(
+                    pct_start))
+
+        # Validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError(
+                "anneal_strategy must be one of 'cos' or 'linear', instead got {}".
+                format(anneal_strategy))
+        elif anneal_strategy == 'cos':
+            self.anneal_func = self._annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = self._annealing_linear
+
+        super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)
+
+    def _annealing_cos(self, start, end, pct):
+        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        cos_out = math.cos(math.pi * pct) + 1
+        return end + (start - end) / 2.0 * cos_out
+
+    def _annealing_linear(self, start, end, pct):
+        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        return (end - start) * pct + start
+
+    def get_lr(self):
+        computed_lr = 0.0
+        step_num = self.last_epoch
+
+        if step_num > self.total_steps:
+            raise ValueError(
+                "Tried to step {} times. The specified number of total steps is {}"
+                .format(step_num + 1, self.total_steps))
+        start_step = 0
+        for i, phase in enumerate(self._schedule_phases):
+            end_step = phase['end_step']
+            if step_num <= end_step or i == len(self._schedule_phases) - 1:
+                pct = (step_num - start_step) / (end_step - start_step)
+                computed_lr = self.anneal_func(phase['start_lr'],
+                                               phase['end_lr'], pct)
+                break
+            start_step = phase['end_step']
+
+        return computed_lr
--
GitLab
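As a quick check of the OneCycleDecay schedule itself, the short standalone sketch below (again not part of the patch, with purely illustrative values) traces the learning rate over one cycle; it assumes PaddlePaddle is installed and the repository root is on PYTHONPATH.

    # Illustrative sketch only: print the shape of the schedule produced by
    # the OneCycleDecay class added in lr_scheduler.py. Values are assumptions.
    from ppocr.optimizer.lr_scheduler import OneCycleDecay

    sched = OneCycleDecay(
        max_lr=0.001,         # peak lr; the starting lr is max_lr / div_factor (25 by default)
        epochs=10,
        steps_per_epoch=100,  # total_steps = epochs * steps_per_epoch = 1000
        pct_start=0.3,        # 30% of the steps ramp up towards max_lr
        anneal_strategy='cos')

    for step in range(10 * 100):
        if step % 100 == 0:
            # rises towards max_lr during the first 30% of the steps, then
            # anneals down towards max_lr / (div_factor * final_div_factor)
            print(step, sched.get_lr())
        sched.step()

With three_phase=True the same class inserts an intermediate phase that returns to the initial learning rate before the final annihilation phase, mirroring the referenced PyTorch OneCycleLR behaviour.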