# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import paddle
from typing import Dict, List

from ppcls.utils import logger

from . import optimizer

__all__ = ['build_optimizer']


def build_lr_scheduler(lr_config, epochs, step_each_epoch):
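    """Build a learning rate scheduler from ``lr_config``.

    ``epochs`` and ``step_each_epoch`` are injected into the config. If the
    config names a scheduler, the matching class in ``learning_rate`` is
    instantiated and a ``paddle.optimizer.lr.LRScheduler`` is returned;
    otherwise the constant ``learning_rate`` value from the config is returned.
    """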
    from . import learning_rate
    lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
    if 'name' in lr_config:
        lr_name = lr_config.pop('name')
        lr = getattr(learning_rate, lr_name)(**lr_config)
        if isinstance(lr, paddle.optimizer.lr.LRScheduler):
            return lr
        else:
            return lr()
    else:
        lr = lr_config['learning_rate']
    return lr


# model_list is None in static graph
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
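    """Build one optimizer (and lr scheduler) per entry in the optimizer config.

    ``config`` is either the full config (an AttrDict-style dict with
    ``Optimizer`` and ``Arch`` sections) or a list of per-scope optimizer
    configs; each optimizer is bound to the model in ``model_list`` whose
    class name matches its ``scope``. Returns ``(optim_list, lr_list)``.
    """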
    config = copy.deepcopy(config)
    if isinstance(config, dict):
        # convert to [{optim_name1: {scope: xxx, **optim_cfg}}, {optim_name2: {scope: xxx, **optim_cfg}}, ...]
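        # e.g. Optimizer: {name: Momentum, lr: {...}, ...} with Arch.name
        # "ResNet50" becomes [{'Momentum': {'scope': 'ResNet50', 'lr': {...}, ...}}]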
        optim_name = config.Optimizer.pop('name')
        config: List[Dict[str, Dict]] = [{
            optim_name: {
                'scope': config.Arch.name,
                **config.Optimizer
            }
        }]
    optim_list = []
    lr_list = []
    for optim_item in config:
        # optim_item = {optim_name: {'scope': xxx, **optim_cfg}}
        # step1 build lr
        optim_name = list(optim_item.keys())[0]  # get optim_name
        optim_scope = optim_item[optim_name].pop('scope')  # get scope
        optim_cfg = optim_item[optim_name]  # get optim_cfg

        lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
        logger.debug("build lr ({}) for scope ({}) success..".format(
            lr, optim_scope))
        # step2 build regularization
        if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
            if 'weight_decay' in optim_cfg:
                logger.warning(
                    "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
                )
            reg_config = optim_cfg.pop('regularizer')
            reg_name = reg_config.pop('name') + 'Decay'
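            # e.g. {'name': 'L2', 'coeff': 1e-4} maps to
            # paddle.regularizer.L2Decay(coeff=1e-4)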
            reg = getattr(paddle.regularizer, reg_name)(**reg_config)
            optim_cfg["weight_decay"] = reg
            logger.debug("build regularizer ({}) success..".format(reg))
        # step3 build optimizer
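        # optional gradient clipping: clip each gradient whose L2 norm exceeds
        # `clip_norm` (paddle.nn.ClipGradByNorm)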
        if 'clip_norm' in optim_cfg:
            clip_norm = optim_cfg.pop('clip_norm')
            grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
        else:
            grad_clip = None
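        # collect the sub-model whose class name matches this optimizer's scope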
        optim_model = []
        for m in model_list:
            if m.__class__.__name__ == optim_scope:
                optim_model.append(m)
        assert len(optim_model) == 1 and len(optim_model[0].parameters()) > 0, \
            f"Invalid optim model for optim scope({optim_scope}), number of optim_model={len(optim_model)}, and number of optim_model's params={len(optim_model[0].parameters())}"
        optim = getattr(optimizer, optim_name)(
            learning_rate=lr, grad_clip=grad_clip,
            **optim_cfg)(model_list=optim_model)
        logger.debug("build optimizer ({}) for scope ({}) success..".format(
            optim, optim_scope))
        optim_list.append(optim)
        lr_list.append(lr)
    return optim_list, lr_list
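
# A minimal usage sketch (comment only, not executed): assuming `config` is the
# full ppcls AttrDict config with `Global`, `Arch` and `Optimizer` sections, and
# `model`/`train_loader` are a hypothetical nn.Layer and dataloader:
#
#     optim_list, lr_list = build_optimizer(
#         config, epochs=config.Global.epochs,
#         step_each_epoch=len(train_loader), model_list=[model])
#     optim, lr_sch = optim_list[0], lr_list[0]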