# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
import paddle.nn as nn

import paddle.optimizer as optimizer
import paddle.regularizer as regularizer
from paddle import cos

from ppdet.core.workspace import register, serializable

__all__ = ['LearningRate', 'OptimizerBuilder']

from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)


@serializable
class PiecewiseDecay(object):
    """
    Multi step learning rate decay

    Args:
        gamma (float | list): learning rate decay factor. A single float is
            expanded into one decay factor per milestone by dividing it by
            10 at each step; a list is used as given.
        milestones (list): epochs at which the learning rate decays
    """

    def __init__(self, gamma=[0.1, 0.01], milestones=[8, 11]):
        super(PiecewiseDecay, self).__init__()
        if not isinstance(gamma, list):
            self.gamma = []
            for i in range(len(milestones)):
                self.gamma.append(gamma / 10**i)
        else:
            self.gamma = gamma
        self.milestones = milestones

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        if boundary is not None:
            boundary.extend([int(step_per_epoch) * i for i in self.milestones])

        if value is not None:
            for i in self.gamma:
                value.append(base_lr * i)

        return optimizer.lr.PiecewiseDecay(boundary, value)
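

# Illustrative sketch, not part of the original module: shows how the default
# PiecewiseDecay config turns epoch milestones into iteration boundaries and
# absolute learning-rate values. The step_per_epoch of 1000 and the pre-filled
# warmup lists below are assumed placeholders.
def _piecewise_decay_example():
    decay = PiecewiseDecay(gamma=[0.1, 0.01], milestones=[8, 11])
    boundary, value = [], [0.01]  # warmup normally pre-fills these lists
    # boundary becomes [8000, 11000]; value becomes [0.01, 0.001, 0.0001]
    return decay(base_lr=0.01, boundary=boundary, value=value,
                 step_per_epoch=1000)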


@serializable
class LinearWarmup(object):
    """
    Warm up learning rate linearly

    Args:
        steps (int): warm up steps
        start_factor (float): initial learning rate factor
    """

    def __init__(self, steps=500, start_factor=1. / 3):
        super(LinearWarmup, self).__init__()
        self.steps = steps
        self.start_factor = start_factor

    def __call__(self, base_lr):
        boundary = []
        value = []
        for i in range(self.steps + 1):
            alpha = i / self.steps
            factor = self.start_factor * (1 - alpha) + alpha
            lr = base_lr * factor
            value.append(lr)
            if i > 0:
                boundary.append(i)
        return boundary, value
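

# Illustrative sketch, not part of the original module: traces LinearWarmup for
# a tiny assumed configuration (steps=4, start_factor=0.25) so the shapes of
# the returned lists are easy to see.
def _linear_warmup_example(base_lr=0.01):
    boundary, value = LinearWarmup(steps=4, start_factor=0.25)(base_lr)
    # boundary -> [1, 2, 3, 4]
    # value    -> [base_lr * f for f in (0.25, 0.4375, 0.625, 0.8125, 1.0)]
    return boundary, value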


@register
class LearningRate(object):
    """
    Learning Rate configuration

    Args:
        base_lr (float): base learning rate
        schedulers (list): learning rate schedulers
    """
    __category__ = 'optim'

    def __init__(self,
                 base_lr=0.01,
                 schedulers=[PiecewiseDecay(), LinearWarmup()]):
        super(LearningRate, self).__init__()
        self.base_lr = base_lr
        self.schedulers = schedulers

    def __call__(self, step_per_epoch):
        # TODO: split warmup & decay 
        # warmup
        boundary, value = self.schedulers[1](self.base_lr)
        # decay
        decay_lr = self.schedulers[0](self.base_lr, boundary, value,
                                      step_per_epoch)
        return decay_lr
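

# Illustrative sketch, not part of the original module: composes the default
# warmup and piecewise schedulers exactly as LearningRate.__call__ does. The
# step_per_epoch of 1000 is an assumed placeholder; in training it comes from
# the dataloader length.
def _learning_rate_example():
    lr_config = LearningRate(
        base_lr=0.01, schedulers=[PiecewiseDecay(), LinearWarmup()])
    # returns a paddle.optimizer.lr.PiecewiseDecay covering warmup + decay
    return lr_config(step_per_epoch=1000)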


@register
class OptimizerBuilder(object):
    """
    Build optimizer handles

    Args:
        clip_grad_by_norm (float): if set, gradients are clipped by this
            global norm before the update; clipping is disabled when None
        regularizer (dict): weight decay config with `type` (e.g. 'L2')
            and `factor` entries
        optimizer (dict): optimizer config with a `type` entry (e.g.
            'Momentum') plus keyword arguments for that optimizer class
    """
    __category__ = 'optim'

    def __init__(self,
                 clip_grad_by_norm=None,
                 regularizer={'type': 'L2',
                              'factor': .0001},
                 optimizer={'type': 'Momentum',
                            'momentum': .9}):
        self.clip_grad_by_norm = clip_grad_by_norm
        self.regularizer = regularizer
        self.optimizer = optimizer

    def __call__(self, learning_rate, params=None):
        if self.clip_grad_by_norm is not None:
            # paddle 2.x exposes global-norm clipping as nn.ClipGradByGlobalNorm
            grad_clip = nn.ClipGradByGlobalNorm(
                clip_norm=self.clip_grad_by_norm)
        else:
            grad_clip = None

        if self.regularizer:
            reg_type = self.regularizer['type'] + 'Decay'
            reg_factor = self.regularizer['factor']
            regularization = getattr(regularizer, reg_type)(reg_factor)
        else:
            regularization = None

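        # look up the optimizer class named by the config 'type' and pass the
        # remaining entries through as keyword arguments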
        optim_args = self.optimizer.copy()
        optim_type = optim_args['type']
        del optim_args['type']
        op = getattr(optimizer, optim_type)
        return op(learning_rate=learning_rate,
                  parameters=params,
                  weight_decay=regularization,
                  grad_clip=grad_clip,
                  **optim_args)
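

# Illustrative sketch, not part of the original module: wires LearningRate and
# OptimizerBuilder together the way a training loop would. The nn.Linear model,
# the clip threshold of 35.0, and step_per_epoch of 1000 are assumed
# placeholders standing in for a real detection model and dataset.
def _optimizer_builder_example():
    model = nn.Linear(4, 2)
    lr = LearningRate()(step_per_epoch=1000)
    builder = OptimizerBuilder(
        clip_grad_by_norm=35.0,
        regularizer={'type': 'L2',
                     'factor': 1e-4},
        optimizer={'type': 'Momentum',
                   'momentum': 0.9})
    return builder(lr, model.parameters())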