# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import math

from .. import unique_name

# Public learning-rate schedulers exported by this module.
__all__ = [
    'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
    'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay'
]


class LearningRateDecay(object):
    """
    Base class for learning rate decay schedules.

    Subclasses implement ``step()``, which computes the learning rate for the
    current value of ``step_num``. Calling the instance returns that learning
    rate and then advances ``step_num`` by ``step_size``. If ``step()``
    produced a plain Python float, it is first wrapped in a global variable
    via ``create_lr_var``; any other result is returned unchanged.

    Args:
        begin: initial value of the step counter ``step_num``.
        step: amount added to ``step_num`` after every call.
        dtype: data type used when creating learning-rate variables.
    """

    def __init__(self, begin=0, step=1, dtype='float32'):
        self.step_num = begin  # current step counter
        self.step_size = step  # increment applied after each __call__
        self.dtype = dtype  # dtype passed to create_global_var

    def __call__(self):
        """Return the learning rate for the current step, then advance."""
        current_lr = self.step()
        # Only plain floats need wrapping; variables built by step() via
        # layer ops are passed through as-is.
        if not isinstance(current_lr, float):
            result = current_lr
        else:
            result = self.create_lr_var(current_lr)
        self.step_num = self.step_num + self.step_size
        return result

    def create_lr_var(self, lr):
        """Create a persistable, shape-[1] global variable holding ``lr``."""
        from .. import layers
        return layers.create_global_var(
            name=unique_name.generate("learning_rate"),
            shape=[1],
            value=float(lr),
            dtype=self.dtype,
            persistable=True)

    def step(self):
        """Compute the learning rate for ``step_num``; must be overridden."""
        raise NotImplementedError()


class PiecewiseDecay(LearningRateDecay):
    """
    Piecewise-constant learning rate schedule.

    ``values[i]`` is returned while ``step_num < boundaries[i]``; once
    ``step_num`` has reached every boundary, ``values[-1]`` is returned.

    Args:
        boundaries: step numbers at which the learning rate changes
            (presumably ascending; this is not checked).
        values: learning rate for each interval; must have exactly one more
            element than ``boundaries``.
        begin: initial value of the step counter.
        step: increment applied to the step counter after each call.
        dtype: data type of the created learning-rate variables.

    Raises:
        ValueError: if ``len(values) - len(boundaries) != 1``.
    """

    def __init__(self, boundaries, values, begin, step=1, dtype='float32'):
        super(PiecewiseDecay, self).__init__(begin, step, dtype)
        # N boundaries split the steps into N + 1 intervals, one value each.
        # Without this check a mismatched list either raises IndexError
        # later in step() or silently ignores a boundary/value.
        if len(values) - len(boundaries) != 1:
            raise ValueError("len(values) - len(boundaries) should be 1")
        self.boundaries = boundaries
        self.values = values

        # Pre-create one variable per value so step() can return them as-is.
        self.vars = [self.create_lr_var(value) for value in values]

    def step(self):
        for boundary, lr_var in zip(self.boundaries, self.vars):
            if self.step_num < boundary:
                return lr_var
        # Past the last boundary: use the final interval's rate.
        return self.vars[-1]


class NaturalExpDecay(LearningRateDecay):
    """
    Natural exponential learning rate decay.

    Computes::

        decayed_lr = learning_rate * exp(-decay_rate * step_num / decay_steps)

    When ``staircase`` is True the ratio ``step_num / decay_steps`` is
    floored first, so the rate only changes every ``decay_steps`` steps.

    Args:
        learning_rate: initial learning rate.
        decay_steps: number of steps that make up one decay period.
        decay_rate: decay coefficient in the exponent.
        staircase: floor the step ratio when True.
        begin: initial value of the step counter.
        step: increment applied to the step counter after each call.
        dtype: data type of the created learning-rate variables.
    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(NaturalExpDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        from .. import layers
        # float() forces true division: under Python 2 two ints would
        # truncate here, turning the smooth schedule into a staircase.
        div_res = self.create_lr_var(float(self.step_num) / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)
        decayed_lr = self.learning_rate * layers.exp(-1 * self.decay_rate *
                                                     div_res)

        return decayed_lr


class ExponentialDecay(LearningRateDecay):
    """
    Exponential learning rate decay.

    Computes::

        decayed_lr = learning_rate * decay_rate ** (step_num / decay_steps)

    When ``staircase`` is True the ratio ``step_num / decay_steps`` is
    floored first, so the rate only changes every ``decay_steps`` steps.

    Args:
        learning_rate: initial learning rate.
        decay_steps: number of steps that make up one decay period.
        decay_rate: base of the exponential decay.
        staircase: floor the step ratio when True.
        begin: initial value of the step counter.
        step: increment applied to the step counter after each call.
        dtype: data type of the created learning-rate variables.
    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(ExponentialDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        from .. import layers
        # float() forces true division: under Python 2 two ints would
        # truncate here, turning the smooth schedule into a staircase.
        div_res = self.create_lr_var(float(self.step_num) / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)

        decayed_lr = self.learning_rate * (self.decay_rate**div_res)

        return decayed_lr


class InverseTimeDecay(LearningRateDecay):
    """
    Inverse-time learning rate decay.

    Computes::

        decayed_lr = learning_rate / (1 + decay_rate * step_num / decay_steps)

    When ``staircase`` is True the ratio ``step_num / decay_steps`` is
    floored first, so the rate only changes every ``decay_steps`` steps.

    Args:
        learning_rate: initial learning rate.
        decay_steps: number of steps that make up one decay period.
        decay_rate: coefficient applied to the step ratio.
        staircase: floor the step ratio when True.
        begin: initial value of the step counter.
        step: increment applied to the step counter after each call.
        dtype: data type of the created learning-rate variables.
    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(InverseTimeDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        from .. import layers
        # float() forces true division: under Python 2 two ints would
        # truncate here, turning the smooth schedule into a staircase.
        div_res = self.create_lr_var(float(self.step_num) / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)

        decayed_lr = self.learning_rate / (1 + self.decay_rate * div_res)

        return decayed_lr


class PolynomialDecay(LearningRateDecay):
    """
    Polynomial learning rate decay.

    Without ``cycle``::

        decayed_lr = (learning_rate - end_learning_rate) *
                     (1 - min(step_num, decay_steps) / decay_steps) ** power
                     + end_learning_rate

    With ``cycle`` the decay window is ``decay_steps`` multiplied by
    ``ceil(step_num / decay_steps)`` (1 at step 0), and ``step_num`` is not
    clamped.

    Args:
        learning_rate: initial learning rate.
        decay_steps: number of steps in one decay window.
        end_learning_rate: learning rate reached at the end of the window.
        power: exponent of the polynomial.
        cycle: restart the decay with a longer window instead of clamping.
        begin: initial value of the step counter.
        step: increment applied to the step counter after each call.
        dtype: data type of the created learning-rate variables.
    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 end_learning_rate=0.0001,
                 power=1.0,
                 cycle=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(PolynomialDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle

    def step(self):
        from .. import layers
        tmp_step_num = self.step_num
        tmp_decay_steps = self.decay_steps

        if self.cycle:
            # float() forces true division: under Python 2 an int ratio
            # truncates to 0 for 0 < step < decay_steps, so ceil() gives 0
            # and the decay window collapses to zero length.
            div_res = layers.ceil(
                self.create_lr_var(float(tmp_step_num) / self.decay_steps))

            # At step 0 ceil() yields 0; use one full window instead.
            if float(tmp_step_num) == 0.0:
                div_res = 1.0
            tmp_decay_steps = self.decay_steps * div_res
        else:
            # Clamp the step so the rate stays at end_learning_rate after
            # decay_steps.
            tmp_step_num = self.create_lr_var(tmp_step_num
                                              if tmp_step_num < self.decay_steps
                                              else self.decay_steps)

        decayed_lr = (self.learning_rate - self.end_learning_rate) * \
            ((1 - tmp_step_num / tmp_decay_steps) ** self.power) + self.end_learning_rate
        return decayed_lr


class CosineDecay(LearningRateDecay):
    """
    Cosine learning rate decay, stepped once per epoch.

    Computes::

        epoch = floor(step_num / step_each_epoch)
        decayed_lr = learning_rate * 0.5 * (cos(epoch * pi / epochs) + 1)

    Args:
        learning_rate: initial learning rate.
        step_each_epoch: number of steps in one epoch.
        epochs: total number of epochs over which the cosine spans.
        begin: initial value of the step counter.
        step: increment applied to the step counter after each call.
        dtype: data type of the created learning-rate variables.
    """

    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 epochs,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(CosineDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.step_each_epoch = step_each_epoch
        self.epochs = epochs

    def step(self):
        from .. import layers
        epoch_ratio = self.create_lr_var(self.step_num / self.step_each_epoch)
        # Whole epochs completed so far.
        epoch_index = layers.floor(epoch_ratio)
        cosine = layers.cos(epoch_index * math.pi / self.epochs)
        return self.learning_rate * 0.5 * (cosine + 1)


class NoamDecay(LearningRateDecay):
    """
    Noam learning rate schedule (linear warm-up, then inverse square root).

    Computes::

        decayed_lr = d_model ** -0.5 *
                     min(step_num ** -0.5, warmup_steps ** -1.5 * step_num)

    The counter starts at ``begin=1`` by default; a ``step_num`` of 0 would
    make ``step_num ** -0.5`` raise ZeroDivisionError.

    Args:
        d_model: model dimension used to scale the learning rate.
        warmup_steps: number of warm-up steps.
        begin: initial value of the step counter.
        step: increment applied to the step counter after each call.
        dtype: data type of the created learning-rate variables.
    """

    def __init__(self, d_model, warmup_steps, begin=1, step=1, dtype='float32'):
        super(NoamDecay, self).__init__(begin, step, dtype)
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def step(self):
        from .. import layers
        # Inverse square-root decay term.
        decay_term = self.create_lr_var(self.step_num**-0.5)
        # Linear warm-up term.
        warmup_term = self.create_lr_var(
            (self.warmup_steps**-1.5) * self.step_num)
        scale = self.d_model**-0.5
        return scale * layers.elementwise_min(decay_term, warmup_term)