# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import math
import paddle
import paddle.nn as nn

import paddle.optimizer as optimizer
import paddle.regularizer as regularizer

from ppdet.core.workspace import register, serializable
import copy

from .adamw import AdamWDL, build_adamwdl

# Public API of this module. The scheduler classes below are not listed here;
# `LearningRate` resolves them by name via `sys.modules[__name__]` when it is
# given dict-style scheduler configs.
__all__ = ['LearningRate', 'OptimizerBuilder']

from ppdet.utils.logger import setup_logger

# Module-level logger, named after this module.
logger = setup_logger(__name__)


@serializable
class CosineDecay(object):
    """
    Cosine learning rate decay

    Args:
        max_epochs (int): max epochs for the training process.
            If you combine cosine decay with warmup, it is recommended that
            the total number of iterations (max_epochs * step_per_epoch) be
            much larger than the number of warmup iterations.
        use_warmup (bool): whether to continue from the warmup
            boundaries/values passed in at call time. Default: True.
        min_lr_ratio (float): minimum learning rate as a ratio of the base
            learning rate. Default: 0.
        last_plateau_epochs (int): use minimum learning rate in
            the last few epochs. Default: 0.
    """

    def __init__(self,
                 max_epochs=1000,
                 use_warmup=True,
                 min_lr_ratio=0.,
                 last_plateau_epochs=0):
        self.max_epochs = max_epochs
        self.use_warmup = use_warmup
        self.min_lr_ratio = min_lr_ratio
        self.last_plateau_epochs = last_plateau_epochs

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        """
        Build the learning rate scheduler.

        Args:
            base_lr (float): peak learning rate; must not be None.
            boundary (list|None): warmup step boundaries. Extended in place
                when warmup is used.
            value (list|None): warmup learning rates. Extended in place when
                warmup is used.
            step_per_epoch (int): iterations per epoch.

        Returns:
            `optimizer.lr.PiecewiseDecay` when warmup is used or
            `last_plateau_epochs > 0`, otherwise
            `optimizer.lr.CosineAnnealingDecay`.
        """
        assert base_lr is not None, "either base LR or values should be provided"

        iters_per_epoch = int(step_per_epoch)
        max_iters = self.max_epochs * iters_per_epoch
        last_plateau_iters = self.last_plateau_epochs * iters_per_epoch
        min_lr = base_lr * self.min_lr_ratio
        # Iterations at or after this index are held at min_lr.
        plateau_start = max_iters - last_plateau_iters

        def cosine_lr(progress, span):
            # Half cosine: base_lr at progress == 0, min_lr at progress == span.
            return min_lr + (base_lr - min_lr) * 0.5 * (
                math.cos(progress * math.pi / span) + 1)

        if self.use_warmup and boundary is not None and value is not None:
            # Continue the warmup piecewise schedule with the cosine curve.
            warmup_iters = len(boundary)
            span = max_iters - warmup_iters - last_plateau_iters
            for step in range(int(boundary[-1]), max_iters):
                boundary.append(step)
                value.append(
                    cosine_lr(step - warmup_iters, span)
                    if step < plateau_start else min_lr)
            return optimizer.lr.PiecewiseDecay(boundary, value)

        if last_plateau_iters > 0:
            # No warmup, but a final plateau: build a per-iteration
            # piecewise schedule instead of CosineAnnealingDecay.
            span = max_iters - last_plateau_iters
            lrs = [
                cosine_lr(step, span) if step < plateau_start else min_lr
                for step in range(max_iters)
            ]
            return optimizer.lr.PiecewiseDecay(list(range(1, max_iters)), lrs)

        return optimizer.lr.CosineAnnealingDecay(
            base_lr, T_max=max_iters, eta_min=min_lr)


@serializable
class PiecewiseDecay(object):
    """
    Multi step learning rate decay

    Args:
        gamma (float | list | tuple): decay factor(s), one per milestone.
            A scalar `g` is expanded to `[g, g / 10, g / 100, ...]`.
        milestones (list): epochs at which to decay learning rate (they are
            multiplied by `step_per_epoch` when the scheduler is built).
        values (list|None): explicit learning rates, one for the stage
            before the first milestone plus one per milestone. When set,
            `gamma` is ignored.
        use_warmup (bool): whether to prepend the warmup boundaries/values
            passed in at call time. Default: True.
    """

    def __init__(self,
                 gamma=[0.1, 0.01],
                 milestones=[8, 11],
                 values=None,
                 use_warmup=True):
        super(PiecewiseDecay, self).__init__()
        if isinstance(gamma, (list, tuple)):
            self.gamma = list(gamma)
        else:
            # scalar gamma: decay by an additional factor of 10 per milestone
            self.gamma = [gamma / 10**i for i in range(len(milestones))]
        self.milestones = milestones
        self.values = values
        self.use_warmup = use_warmup

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        """
        Build an `optimizer.lr.PiecewiseDecay` scheduler.

        Args:
            base_lr (float): base learning rate.
            boundary (list|None): warmup step boundaries; extended in place
                when warmup is used.
            value (list|None): warmup learning rates (one more entry than
                `boundary`); extended in place when warmup is used.
            step_per_epoch (int): iterations per epoch.
        """
        milestone_steps = [int(step_per_epoch) * i for i in self.milestones]

        # Warmup needs both boundaries and values, as in CosineDecay;
        # otherwise the two lists would end up with mismatched lengths.
        if self.use_warmup and boundary is not None and value is not None:
            boundary.extend(milestone_steps)
        else:
            # do not use LinearWarmup
            boundary = milestone_steps
            value = [base_lr]  # during step[0, boundary[0]] is base_lr

        # self.values is set directly in config
        if self.values is not None:
            assert len(self.milestones) + 1 == len(self.values)
            # The last entry of `value` covers [last warmup boundary, first
            # milestone); replace it with the configured first-stage value so
            # that len(values) == len(boundary) + 1 also when warmup is used.
            return optimizer.lr.PiecewiseDecay(boundary,
                                               value[:-1] + list(self.values))

        # value is computed by self.gamma
        value.extend(base_lr * g for g in self.gamma)

        return optimizer.lr.PiecewiseDecay(boundary, value)


@serializable
class LinearWarmup(object):
    """
    Warm up learning rate linearly

    Args:
        steps (int): warm up steps
        start_factor (float): initial learning rate factor
        epochs (int|float|None): use epochs as warm up steps, the priority
            of `epochs` is higher than `steps`. Default: None.
    """

    def __init__(self, steps=500, start_factor=1. / 3, epochs=None):
        super(LinearWarmup, self).__init__()
        self.steps = steps
        self.start_factor = start_factor
        self.epochs = epochs

    def __call__(self, base_lr, step_per_epoch):
        """
        Build the warmup part of a piecewise schedule.

        Returns:
            (boundary, value): `boundary` is `[1, ..., W]` and `value` has
            `W + 1` entries ramping linearly from `base_lr * start_factor`
            to `base_lr`, where `W` is the number of warmup steps (>= 1).
        """
        if self.epochs is not None:
            # round so that fractional epochs / float step counts still give
            # an integer number of steps usable by range()
            warmup_steps = int(round(self.epochs * step_per_epoch))
        else:
            warmup_steps = int(self.steps)
        warmup_steps = max(warmup_steps, 1)

        value = []
        for i in range(warmup_steps + 1):
            alpha = i / warmup_steps
            factor = self.start_factor * (1 - alpha) + alpha
            value.append(base_lr * factor)
        boundary = list(range(1, warmup_steps + 1))
        return boundary, value


@serializable
class ExpWarmup(object):
    """
    Warm up learning rate in exponential mode
    Args:
        steps (int): warm up steps.
        epochs (int|float|None): use epochs as warm up steps, the priority
            of `epochs` is higher than `steps`. Default: None.
        power (int): Exponential coefficient. Default: 2.
    """

    def __init__(self, steps=1000, epochs=None, power=2):
        super(ExpWarmup, self).__init__()
        self.steps = steps
        self.epochs = epochs
        self.power = power

    def __call__(self, base_lr, step_per_epoch):
        """
        Build the warmup part of a piecewise schedule.

        Returns:
            (boundary, value): `boundary` is `[1, ..., W]` and `value[i]` is
            `base_lr * (i / W) ** power` for `i` in `0..W`, where `W` is the
            number of warmup steps (>= 1).
        """
        if self.epochs is not None:
            # round so that fractional epochs / float step counts still give
            # an integer number of steps usable by range()
            warmup_steps = int(round(self.epochs * step_per_epoch))
        else:
            warmup_steps = int(self.steps)
        warmup_steps = max(warmup_steps, 1)

        value = [
            base_lr * (i / float(warmup_steps))**self.power
            for i in range(warmup_steps + 1)
        ]
        boundary = list(range(1, warmup_steps + 1))
        return boundary, value


@register
class LearningRate(object):
    """
    Learning Rate configuration

    Args:
        base_lr (float): base learning rate
        schedulers (list): learning rate schedulers. The first entry is the
            decay scheduler; the optional second entry is the warmup
            scheduler. Entries may be scheduler instances, or dicts whose
            `name` key names a scheduler class defined in this module and
            whose remaining keys are passed to its constructor.
    """
    __category__ = 'optim'

    def __init__(self,
                 base_lr=0.01,
                 schedulers=[PiecewiseDecay(), LinearWarmup()]):
        super(LearningRate, self).__init__()
        self.base_lr = base_lr
        self.schedulers = []

        # Deep copy so that popping "name" below never mutates the caller's
        # config or the shared default list.
        schedulers = copy.deepcopy(schedulers)
        module = sys.modules[__name__]
        for sched in schedulers:
            if isinstance(sched, dict):
                # support dict sched instantiate
                sched_name = sched.pop("name")
                self.schedulers.append(getattr(module, sched_name)(**sched))
            else:
                self.schedulers.append(sched)

    def __call__(self, step_per_epoch):
        """Build the scheduler for the given number of iterations per epoch."""
        assert len(self.schedulers) >= 1
        decay_sched = self.schedulers[0]
        # Without a second (warmup) scheduler there is nothing to warm up
        # with, so use the no-warmup path instead of raising IndexError.
        if not decay_sched.use_warmup or len(self.schedulers) < 2:
            return decay_sched(base_lr=self.base_lr,
                               step_per_epoch=step_per_epoch)

        # TODO: split warmup & decay
        # warmup
        boundary, value = self.schedulers[1](self.base_lr, step_per_epoch)
        # decay
        return decay_sched(self.base_lr, boundary, value, step_per_epoch)


@register
class OptimizerBuilder():
    """
    Build optimizer handles
    Args:
        clip_grad_by_norm (float|None): if set, clip gradients by global norm
            with `nn.ClipGradByGlobalNorm`.
        clip_grad_by_value (float|None): if set (and `clip_grad_by_norm` is
            not), clip gradients to `[-abs(v), abs(v)]`.
        regularizer (dict|None): `{'type': ..., 'factor': ...}`; `type` is
            suffixed with `Decay` and looked up in `paddle.regularizer`.
            `None` or the string `'None'` disables regularization.
        optimizer (dict): `{'type': ..., **kwargs}`; `type` names a class in
            `paddle.optimizer` (or `'AdamWDL'`), remaining keys are passed
            to it. An optional `param_groups` list of dicts, each with a
            `params` list of parameter-name substrings, builds per-group
            settings.
    """
    __category__ = 'optim'

    def __init__(self,
                 clip_grad_by_norm=None,
                 clip_grad_by_value=None,
                 regularizer={'type': 'L2',
                              'factor': .0001},
                 optimizer={'type': 'Momentum',
                            'momentum': .9}):
        self.clip_grad_by_norm = clip_grad_by_norm
        self.clip_grad_by_value = clip_grad_by_value
        self.regularizer = regularizer
        self.optimizer = optimizer

    def __call__(self, learning_rate, model=None):
        if self.clip_grad_by_norm is not None:
            grad_clip = nn.ClipGradByGlobalNorm(
                clip_norm=self.clip_grad_by_norm)
        elif self.clip_grad_by_value is not None:
            var = abs(self.clip_grad_by_value)
            grad_clip = nn.ClipGradByValue(min=-var, max=var)
        else:
            grad_clip = None

        # NOTE: inside this method `regularizer` / `optimizer` refer to the
        # paddle modules, not the constructor arguments of the same name.
        if self.regularizer and self.regularizer != 'None':
            reg_type = self.regularizer['type'] + 'Decay'
            reg_factor = self.regularizer['factor']
            regularization = getattr(regularizer, reg_type)(reg_factor)
        else:
            regularization = None

        # shallow copy: keys are popped below without touching the config
        optim_args = self.optimizer.copy()
        optim_type = optim_args.pop('type')

        if optim_type == 'AdamWDL':
            return build_adamwdl(model, lr=learning_rate, **optim_args)

        # AdamW carries its own decoupled weight decay in optim_args.
        if optim_type != 'AdamW':
            optim_args['weight_decay'] = regularization

        op = getattr(optimizer, optim_type)

        if 'param_groups' in optim_args:
            params = self._build_param_groups(model,
                                              optim_args.pop('param_groups'))
        else:
            params = [p for p in model.parameters() if p.trainable is True]

        return op(learning_rate=learning_rate,
                  parameters=params,
                  grad_clip=grad_clip,
                  **optim_args)

    @staticmethod
    def _build_param_groups(model, param_groups):
        """
        Split the model's trainable parameters into optimizer param groups.

        Each group in `param_groups` selects the trainable parameters whose
        name contains any of its `params` substrings; its other keys are kept
        as group settings. Trainable parameters matched by no group are put
        in one trailing default group.
        """
        assert isinstance(param_groups, list), \
            'optimizer param_groups should be a list'

        params, visited = [], set()
        for group in param_groups:
            assert isinstance(group, dict) and 'params' in group and \
                isinstance(group['params'], list), \
                "each param group should be a dict with a list 'params'"
            _params = {
                n: p
                for n, p in model.named_parameters()
                if any(k in n for k in group['params']) and p.trainable is True
            }
            _group = group.copy()
            _group.update({'params': list(_params.values())})

            params.append(_group)
            visited.update(_params.keys())

        # Always keep the remaining trainable params. Previously they were
        # dropped when no group matched anything (all params "remaining").
        ext_params = [
            p for n, p in model.named_parameters()
            if n not in visited and p.trainable is True
        ]
        if ext_params:
            params.append({'params': ext_params})
        return params