# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import paddle
from typing import Dict, List

from ppcls.utils import logger

from . import optimizer

__all__ = ['build_optimizer']


def build_lr_scheduler(lr_config, epochs, step_each_epoch):
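    """Build a learning-rate scheduler from a config dict.

    Args:
        lr_config (dict): lr config. If it contains a 'name' key, the named
            class in ppcls/optimizer/learning_rate.py is instantiated with
            the remaining keys; otherwise the raw 'learning_rate' value is
            returned as-is.
        epochs (int): total number of training epochs, injected into the
            config.
        step_each_epoch (int): number of iterations per epoch, injected into
            the config.

    Returns:
        paddle.optimizer.lr.LRScheduler or float
    """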
    from . import learning_rate
    lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
    if 'name' in lr_config:
        lr_name = lr_config.pop('name')
        lr = getattr(learning_rate, lr_name)(**lr_config)
        if isinstance(lr, paddle.optimizer.lr.LRScheduler):
            return lr
        else:
            return lr()
    else:
        lr = lr_config['learning_rate']
    return lr
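
# A minimal usage sketch (values are illustrative; 'Cosine' is one of the
# scheduler classes defined in ppcls/optimizer/learning_rate.py):
#
#   lr_cfg = {'name': 'Cosine', 'learning_rate': 0.04}
#   lr = build_lr_scheduler(lr_cfg, epochs=100, step_each_epoch=312)
#
# Without a 'name' key, the raw learning-rate value is returned unchanged:
#
#   build_lr_scheduler({'learning_rate': 0.01}, 100, 312)  # -> 0.01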


# model_list is None in static graph
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
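    """Build optimizer(s) and lr scheduler(s) from an optimizer config.

    Args:
        config (dict or list): optimizer config; a single dict is normalized
            into a one-element list whose only scope is "all".
        epochs (int): total number of training epochs.
        step_each_epoch (int): number of iterations per epoch.
        model_list (list, optional): layers to optimize, typically
            [Arch, Loss]; None in static-graph mode.

    Returns:
        A single (optim, lr) pair in static-graph mode, otherwise
        (optim_list, lr_list).
    """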
    optim_config = copy.deepcopy(config)
    if isinstance(optim_config, dict):
        # convert {'name': xxx, **optim_cfg} to [{xxx: {'scope': 'all', **optim_cfg}}]
        optim_name = optim_config.pop("name")
        optim_config: List[Dict[str, Dict]] = [{
            optim_name: {
                'scope': "all",
                **optim_config
            }
        }]
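        # For illustration (keys mirror the 'Optimizer' section of a
        # PaddleClas YAML; the values here are hypothetical):
        #   {'name': 'Momentum', 'momentum': 0.9, 'lr': {...}}
        # becomes
        #   [{'Momentum': {'scope': 'all', 'momentum': 0.9, 'lr': {...}}}]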
    optim_list = []
    lr_list = []
    """NOTE:
    Currently only support optim objets below.
    1. single optimizer config.
H
HydrogenSulfate 已提交
63 64
    2. model(entire Arch), backbone, neck, head.
    3. loss(entire Loss), specific loss listed in ppcls/loss/__init__.py.
H
HydrogenSulfate 已提交
65
    """
    for optim_item in optim_config:
        # optim_item = {optim_name: {'scope': xxx, **optim_cfg}}
        # step1 build lr
        optim_name = list(optim_item.keys())[0]  # get optim_name
        optim_scope_list = optim_item[optim_name].pop('scope').split(
            ' ')  # get optim_scope list
        optim_cfg = optim_item[optim_name]  # get optim_cfg

        lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
        logger.info("build lr ({}) for scope ({}) success..".format(
76
            lr.__class__.__name__, optim_scope_list))
        # step2 build regularization
        if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
            if 'weight_decay' in optim_cfg:
                logger.warning(
                    "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
                )
            reg_config = optim_cfg.pop('regularizer')
            reg_name = reg_config.pop('name') + 'Decay'
            reg = getattr(paddle.regularizer, reg_name)(**reg_config)
            optim_cfg["weight_decay"] = reg
            logger.info("build regularizer ({}) for scope ({}) success..".
88
                        format(reg.__class__.__name__, optim_scope_list))
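            # e.g. {'name': 'L2', 'coeff': 1e-4} resolves to
            # paddle.regularizer.L2Decay(coeff=1e-4), which the optimizer
            # receives through its weight_decay argument (values illustrative).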
        # step3 build optimizer
        if 'clip_norm' in optim_cfg:
            clip_norm = optim_cfg.pop('clip_norm')
            grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
            logger.info("build gradclip ({}) for scope ({}) success..".format(
                grad_clip.__class__.__name__, optim_scope_list))
        else:
            grad_clip = None
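        # e.g. a config entry {'clip_norm': 10.0} (value illustrative) yields
        # paddle.nn.ClipGradByNorm(clip_norm=10.0) as grad_clip.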
        optim_model = []

        # for static graph
        if model_list is None:
            optim = getattr(optimizer, optim_name)(
                learning_rate=lr, grad_clip=grad_clip,
                **optim_cfg)(model_list=optim_model)
            return optim, lr

        # for dynamic graph
        for scope in optim_scope_list:
            if scope == "all":
                optim_model += model_list
            elif scope == "model":
                optim_model += [model_list[0], ]
            elif scope in ["backbone", "neck", "head"]:
                optim_model += [getattr(model_list[0], scope, None), ]
            elif scope == "loss":
                optim_model += [model_list[1], ]
            else:
                optim_model += [
                    loss for loss in model_list[1].loss_func
                    if loss.__class__.__name__ == scope
                ]
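        # e.g. scope "backbone head" collects model_list[0].backbone and
        # model_list[0].head, while a scope naming a loss class (say
        # "CenterLoss", illustrative) collects every matching entry of
        # model_list[1].loss_func.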
        # remove invalid items
        optim_model = [
            m for m in optim_model
            if m is not None and len(m.parameters()) > 0
        ]
        assert len(optim_model) > 0, \
            f"optim_model is empty for optim_scope({optim_scope_list})"
        optim = getattr(optimizer, optim_name)(
            learning_rate=lr, grad_clip=grad_clip,
            **optim_cfg)(model_list=optim_model)
        logger.info("build optimizer ({}) for scope ({}) success..".format(
134
            optim.__class__.__name__, optim_scope_list))
        optim_list.append(optim)
        lr_list.append(lr)
    return optim_list, lr_list
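
# A minimal end-to-end sketch for the dynamic-graph path. Config values are
# illustrative, and `arch` / `loss` stand for nn.Layer instances built
# elsewhere (e.g. by ppcls.arch.build_model and ppcls.loss.build_loss):
#
#   optim_cfg = {
#       'name': 'Momentum',
#       'momentum': 0.9,
#       'lr': {'name': 'Cosine', 'learning_rate': 0.04},
#       'regularizer': {'name': 'L2', 'coeff': 1e-4},
#   }
#   optim_list, lr_list = build_optimizer(
#       optim_cfg, epochs=100, step_each_epoch=312, model_list=[arch, loss])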