# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import paddle

from ppcls.utils import logger

from . import optimizer

__all__ = ['build_optimizer']


def build_lr_scheduler(lr_config, epochs, step_each_epoch):
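    """Build the learning rate from an ``lr`` config.

    ``epochs`` and ``step_each_epoch`` are injected into the config. When a
    ``name`` key is present, the matching class in the ``learning_rate``
    module is instantiated; the instance is returned directly if it already
    is a ``paddle.optimizer.lr.LRScheduler``, otherwise it is called to
    obtain the scheduler. Without a ``name`` key, the constant value stored
    under ``learning_rate`` is returned.
    """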
    from . import learning_rate
    lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
    if 'name' in lr_config:
        lr_name = lr_config.pop('name')
        lr = getattr(learning_rate, lr_name)(**lr_config)
        if isinstance(lr, paddle.optimizer.lr.LRScheduler):
            return lr
        else:
            return lr()
    else:
        lr = lr_config['learning_rate']
    return lr


# model_list is None when running in static graph mode
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
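    """Build the optimizer and its learning rate from an optimizer config.

    The config is deep-copied, its ``lr`` sub-config is turned into a
    scheduler via ``build_lr_scheduler``, an optional ``regularizer`` entry
    is converted to a ``paddle.regularizer`` object passed as
    ``weight_decay``, and an optional ``clip_norm`` entry becomes a
    ``paddle.nn.ClipGradByNorm``. The remaining keys are forwarded to the
    optimizer class named by ``name``.

    Returns an ``(optimizer, lr)`` tuple.
    """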
    config = copy.deepcopy(config)
    # step1 build lr
    lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
    logger.debug("build lr ({}) success..".format(lr))
    # step2 build regularization
    if 'regularizer' in config and config['regularizer'] is not None:
        if 'weight_decay' in config:
            logger.warning(
                "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
            )
        reg_config = config.pop('regularizer')
        reg_name = reg_config.pop('name') + 'Decay'
        reg = getattr(paddle.regularizer, reg_name)(**reg_config)
        config["weight_decay"] = reg
        logger.debug("build regularizer ({}) success..".format(reg))
    # step3 build optimizer
    optim_name = config.pop('name')
    if 'clip_norm' in config:
        clip_norm = config.pop('clip_norm')
        grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
    else:
        grad_clip = None
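    # the classes in the local `optimizer` module are constructed with the
    # hyper-parameters first, then called with `model_list` to create the
    # underlying paddle optimizer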
    optim = getattr(optimizer, optim_name)(learning_rate=lr,
                                           grad_clip=grad_clip,
                                           **config)(model_list=model_list)
    logger.debug("build optimizer ({}) success..".format(optim))
    return optim, lr
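

# A minimal usage sketch (illustrative only: the key values below, e.g.
# "Momentum", "Cosine" and "coeff", are assumptions about the classes exposed
# by the local `optimizer`/`learning_rate` modules and `paddle.regularizer`,
# not a verified config):
#
#   optim_config = {
#       'name': 'Momentum',
#       'momentum': 0.9,
#       'lr': {'name': 'Cosine', 'learning_rate': 0.1},
#       'regularizer': {'name': 'L2', 'coeff': 1e-4},
#       'clip_norm': 10.0,
#   }
#   optim, lr_sch = build_optimizer(
#       optim_config, epochs=120, step_each_epoch=100, model_list=[model])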