# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import paddle
from typing import Dict, List

from ppcls.utils import logger

from . import optimizer

__all__ = ['build_optimizer']


def build_lr_scheduler(lr_config, epochs, step_each_epoch):
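    """Build a learning-rate value or scheduler from a config dict.

    A minimal usage sketch (illustrative values; the available scheduler
    names are the classes defined in ./learning_rate.py):

        lr = build_lr_scheduler(
            {'name': 'Cosine', 'learning_rate': 0.1},
            epochs=120, step_each_epoch=500)

    Without a 'name' key, the plain float lr_config['learning_rate'] is
    returned instead of a scheduler object.
    """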
    from . import learning_rate
    lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
    if 'name' in lr_config:
        lr_name = lr_config.pop('name')
        lr = getattr(learning_rate, lr_name)(**lr_config)
        if isinstance(lr, paddle.optimizer.lr.LRScheduler):
            return lr
        else:
            return lr()
    else:
        lr = lr_config['learning_rate']
    return lr


# model_list is None when running in static-graph mode
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
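    """Build optimizer(s) and lr scheduler(s) from an optimizer config.

    `config` is either a single optimizer dict or, for multiple scoped
    optimizers, a list of {optim_name: {...}} dicts. Returns (optim, lr) in
    static-graph mode (model_list is None) and (optim_list, lr_list) in
    dynamic-graph mode.
    """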
    optim_config = copy.deepcopy(config)
    if isinstance(optim_config, dict):
        # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
        optim_name = optim_config.pop("name")
        optim_config: List[Dict[str, Dict]] = [{
            optim_name: {
                'scope': "all",
                **optim_config
            }
        }]
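        # e.g. (illustrative) {'name': 'Momentum', 'momentum': 0.9, 'lr': {...}}
        # becomes [{'Momentum': {'scope': 'all', 'momentum': 0.9, 'lr': {...}}}]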
    optim_list = []
    lr_list = []
    """NOTE:
    Currently only support optim objets below.
    1. single optimizer config.
H
HydrogenSulfate 已提交
63 64
    2. next level uner Arch, such as Arch.backbone, Arch.neck, Arch.head.
    3. loss which has parameters, such as CenterLoss.
H
HydrogenSulfate 已提交
65
    """
    for optim_item in optim_config:
        # optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
        # step1 build lr
        optim_name = list(optim_item.keys())[0]  # get optim_name
        optim_scope = optim_item[optim_name].pop('scope')  # get optim_scope
        optim_cfg = optim_item[optim_name]  # get optim_cfg

        lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
        logger.info("build lr ({}) for scope ({}) success..".format(
H
HydrogenSulfate 已提交
75
            lr, optim_scope))
        # step2 build regularization
        if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
            if 'weight_decay' in optim_cfg:
                logger.warning(
                    "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
                )
            reg_config = optim_cfg.pop('regularizer')
            reg_name = reg_config.pop('name') + 'Decay'
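            # e.g. (illustrative) {'name': 'L2', 'coeff': 1e-4} maps to
            # paddle.regularizer.L2Decay(coeff=1e-4)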
            reg = getattr(paddle.regularizer, reg_name)(**reg_config)
            optim_cfg["weight_decay"] = reg
            logger.info("built regularizer ({}) for scope ({}) successfully.".
                        format(reg, optim_scope))
        # step3 build optimizer
        if 'clip_norm' in optim_cfg:
            clip_norm = optim_cfg.pop('clip_norm')
            grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
        else:
            grad_clip = None
        optim_model = []

        # for static graph
        if model_list is None:
            optim = getattr(optimizer, optim_name)(
                learning_rate=lr, grad_clip=grad_clip,
                **optim_cfg)(model_list=optim_model)
            return optim, lr

        # for dynamic graph
        for i in range(len(model_list)):
            if len(model_list[i].parameters()) == 0:
                continue
            if optim_scope == "all":
                # optimizer for all
                optim_model.append(model_list[i])
            else:
                if optim_scope.endswith("Loss"):
                    # optimizer for loss
                    for m in model_list[i].sublayers(True):
                        if m.__class__.__name__ == optim_scope:
                            optim_model.append(m)
                else:
                    # optimizer for a sub-module of the model, such as backbone, neck, head...
                    if optim_scope == model_list[i].__class__.__name__:
                        optim_model.append(model_list[i])
                    elif hasattr(model_list[i], optim_scope):
                        optim_model.append(getattr(model_list[i], optim_scope))

        optim = getattr(optimizer, optim_name)(
            learning_rate=lr, grad_clip=grad_clip,
            **optim_cfg)(model_list=optim_model)
        logger.info("build optimizer ({}) for scope ({}) success..".format(
H
HydrogenSulfate 已提交
127
            optim, optim_scope))
        optim_list.append(optim)
        lr_list.append(lr)
    return optim_list, lr_list
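
# A minimal usage sketch (hypothetical surrounding names such as train_loader,
# model, and loss_func; in dynamic-graph training the model and any
# parameterized losses are passed together as model_list):
#
#   optim_list, lr_list = build_optimizer(
#       config["Optimizer"], epochs=120,
#       step_each_epoch=len(train_loader),
#       model_list=[model, loss_func])
#   for optim in optim_list:
#       optim.step()
#       optim.clear_grad()
#   for lr in lr_list:
#       if isinstance(lr, paddle.optimizer.lr.LRScheduler):
#           lr.step()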