# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import Any
from typing import Dict
from typing import Text
from typing import Union

from paddle.optimizer.lr import LRScheduler
from typeguard import check_argument_types

from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.dynamic_import import instance_class
from paddlespeech.s2t.utils.log import Log

__all__ = ["WarmupLR", "LRSchedulerFactory"]

logger = Log(__name__).getlog()

SCHEDULER_DICT = {
    "noam": "paddle.optimizer.lr:NoamDecay",
    "expdecaylr": "paddle.optimizer.lr:ExponentialDecay",
    "piecewisedecay": "paddle.optimizer.lr:PiecewiseDecay",
}


def register_scheduler(cls):
    """Register a scheduler class in ``SCHEDULER_DICT`` under its lowercased class name."""
    alias = cls.__name__.lower()
    SCHEDULER_DICT[alias] = cls.__module__ + ":" + cls.__name__
    return cls


@register_scheduler
class WarmupLR(LRScheduler):
    """The WarmupLR scheduler.

    This scheduler is almost the same as the NoamLR scheduler, except for the
    following difference:

    NoamLR:
        lr = optimizer.lr * model_size ** -0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)
    WarmupLR:
        lr = optimizer.lr * warmup_step ** 0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)

    Note that the maximum lr equals optimizer.lr in this scheduler and is
    reached at ``step == warmup_step``.
    """

    def __init__(self,
                 warmup_steps: Union[int, float]=25000,
                 learning_rate=1.0,
                 last_epoch=-1,
                 verbose=False,
                 **kwargs):
        assert check_argument_types()
        self.warmup_steps = warmup_steps
        super().__init__(learning_rate, last_epoch, verbose)

    def __repr__(self):
        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"

    def get_lr(self):
        step_num = self.last_epoch + 1
        return self.base_lr * self.warmup_steps**0.5 * min(
            step_num**-0.5, step_num * self.warmup_steps**-1.5)

    def set_step(self, step: int=None):
        '''
        Update the learning rate in the optimizer according to the given ``step``.
        The new learning rate takes effect on the next ``optimizer.step()`` call.

        Args:
            step (int, optional): the current step to set. Default: None, which
                auto-increments from ``last_epoch``.
        Returns:
            None
        '''
        self.step(epoch=step)
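
# A minimal usage sketch for ``WarmupLR`` (the training-loop names below are
# placeholders, not part of this module): with the defaults
# ``warmup_steps=25000`` and ``learning_rate=1.0``, the learning rate ramps up
# to its maximum of 1.0 at step 25000 and afterwards decays as ``step ** -0.5``.
#
#   scheduler = WarmupLR(warmup_steps=25000, learning_rate=1.0)
#   for global_step, batch in enumerate(train_loader, start=1):
#       ...
#       optimizer.step()
#       scheduler.step()            # advance the schedule once per update
#   # when resuming from a checkpoint:
#   scheduler.set_step(resumed_global_step)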


@register_scheduler
class ConstantLR(LRScheduler):
    """Constant learning rate scheduler.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        last_epoch (int, optional): The index of the last epoch. Can be set to restart training. Default: -1, meaning the initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

    Returns:
        ``ConstantLR`` instance to schedule learning rate.
    """

    def __init__(self, learning_rate, last_epoch=-1, verbose=False):
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        return self.base_lr


def dynamic_import_scheduler(module):
    """Import Scheduler class dynamically.

    Args:
        module (str): module_name:class_name or alias in `SCHEDULER_DICT`

    Returns:
        type: Scheduler class

    """
    module_class = dynamic_import(module, SCHEDULER_DICT)
    assert issubclass(module_class,
                      LRScheduler), f"{module} does not implement LRScheduler"
    return module_class
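
# For example (illustrative only): ``dynamic_import_scheduler("warmuplr")``
# returns the ``WarmupLR`` class registered above, while a fully qualified
# "some.module:SomeScheduler" string is imported directly.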


class LRSchedulerFactory():
    @classmethod
    def from_args(cls, name: str, args: Dict[Text, Any]):
        """Create a scheduler instance by name.

        Args:
            name (str): an alias in ``SCHEDULER_DICT`` or a ``module:class`` string.
            args (Dict[Text, Any]): arguments passed to the scheduler constructor.

        Returns:
            LRScheduler: the scheduler instance.
        """
        module_class = dynamic_import_scheduler(name.lower())
        return instance_class(module_class, args)
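

if __name__ == "__main__":
    # Minimal sketch of the factory API defined above; requires paddlespeech to
    # be installed. The argument dict keys mirror the ``WarmupLR`` constructor
    # and assume ``instance_class`` instantiates the resolved class from that
    # dict; the values here are illustrative, not recommended settings.
    scheduler = LRSchedulerFactory.from_args("warmuplr", {
        "warmup_steps": 25000,
        "learning_rate": 0.001,
    })
    scheduler.step()            # advance to the first step
    print(scheduler)            # -> WarmupLR(warmup_steps=25000)
    print(scheduler.get_lr())   # learning rate after the first step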