utils.py 3.5 KB
Newer Older
W
WuHaobo 已提交
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
W
WuHaobo 已提交
2
#
W
WuHaobo 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
W
WuHaobo 已提交
6 7 8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
W
WuHaobo 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WuHaobo 已提交
14

W
WuHaobo 已提交
15
import six
W
WuHaobo 已提交
16
import types
H
HydrogenSulfate 已提交
17
import paddle
W
WuHaobo 已提交
18 19
from difflib import SequenceMatcher

W
weishengyu 已提交
20
from . import backbone
H
HydrogenSulfate 已提交
21
from typing import Any, Dict, Union
W
WuHaobo 已提交
22

W
WuHaobo 已提交
23 24 25 26 27 28

def get_architectures():
    """
    list the names of all model architectures

    Scans the attributes of the ``backbone`` package and keeps the name of
    every attribute that is a plain function or a class.

    Returns:
        list: attribute names, in the iteration order of ``backbone.__dict__``
    """
    # Anything that is a function or a class is treated as an architecture.
    architecture_kinds = (types.FunctionType, six.class_types)
    return [
        attr_name for attr_name, attr in backbone.__dict__.items()
        if isinstance(attr, architecture_kinds)
    ]


35
def get_blacklist_model_in_static_mode():
    """
    get the blacklist of models for static mode

    Returns:
        the ``__all__`` of ``distilled_vision_transformer`` followed by the
        ``__all__`` of ``vision_transformer``
    """
    # Imported lazily inside the function, as in the original implementation.
    from ppcls.arch.backbone import distilled_vision_transformer as deit_module
    from ppcls.arch.backbone import vision_transformer as vit_module
    return deit_module.__all__ + vit_module.__all__


W
WuHaobo 已提交
42
def similar_architectures(name='', names=None, thresh=0.1, topk=10):
    """
    infer architectures whose names are similar to ``name``

    Every candidate is compared with ``name`` case-insensitively through
    ``difflib.SequenceMatcher.quick_ratio``. Candidates whose name starts
    with ``'__'`` are ignored.

    Args:
        name (str): the name to look up.
        names (list|None): candidate names. ``None`` or an empty sequence
            yields an empty result.
        thresh (float): a candidate is kept only when its score is strictly
            greater than this value.
        topk (int): maximum number of candidates returned.

    Returns:
        list: the kept candidates ordered from the highest score to the
            lowest; candidates with equal scores keep their input order.
    """
    # ``None`` replaces the former mutable ``[]`` default.
    if not names:
        return []

    query = name.lower()
    scored = []
    for candidate in names:
        if candidate.startswith('__'):
            continue
        score = SequenceMatcher(None, candidate.lower(), query).quick_ratio()
        if score > thresh:
            scored.append((candidate, score))

    # list.sort is stable, also with reverse=True, so ties keep input order.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    # Slicing already clamps to the list length, no explicit min() needed.
    return [candidate for candidate, _ in scored[:topk]]
H
HydrogenSulfate 已提交
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72


def _build_from_named_cfg(namespace: Any, cfg: Dict, cfg_label: str) -> Any:
    """Instantiate ``getattr(namespace, cfg['name'])`` with the other entries of ``cfg``.

    Args:
        namespace (Any): object on which the class/function called
            ``cfg['name']`` is looked up.
        cfg (Dict): config holding a ``'name'`` entry; every other entry is
            passed to the looked-up callable as a keyword argument.
        cfg_label (str): name of the config, only used in the error message.

    Returns:
        Any: whatever the looked-up callable returns.

    Raises:
        ValueError: if ``cfg`` has no ``'name'`` entry.
    """
    # Pop from a shallow copy: popping 'name' from the caller's dict would
    # make a second call with the same config fail with "'name' must specified".
    kwargs = dict(cfg)
    if 'name' not in kwargs:
        raise ValueError(f"'name' must specified in {cfg_label}")
    factory = getattr(namespace, kwargs.pop('name'))
    return factory(**kwargs)


def get_param_attr_dict(ParamAttr_config: Union[None, bool, Dict[str, Dict]]
                        ) -> Union[None, bool, paddle.ParamAttr]:
    """parse ParamAttr from a dict

    ``None`` and ``bool`` inputs are returned unchanged. For a dict, the
    optional keys ``'initializer'``, ``'learning_rate'`` and ``'regularizer'``
    are translated into the keyword arguments of ``paddle.ParamAttr``; other
    keys are ignored. The passed config is not modified.

    Args:
        ParamAttr_config (Union[None, bool, Dict[str, Dict]]): ParamAttr_config.
            ``'initializer'`` and ``'regularizer'`` are dicts with a ``'name'``
            entry (looked up in ``paddle.nn.initializer`` and
            ``paddle.regularizer`` respectively) plus the keyword arguments
            for that class; ``'learning_rate'`` is a single int or float.

    Returns:
        Union[None, bool, paddle.ParamAttr]: Generated ParamAttr

    Raises:
        ValueError: if an initializer/regularizer config lacks ``'name'`` or
            ``'learning_rate'`` is neither int nor float.
    """
    if ParamAttr_config is None:
        return None
    if isinstance(ParamAttr_config, bool):
        return ParamAttr_config

    ParamAttr_dict = {}
    if 'initializer' in ParamAttr_config:
        ParamAttr_dict['initializer'] = _build_from_named_cfg(
            paddle.nn.initializer,
            ParamAttr_config.get('initializer'), 'initializer_cfg')
    if 'learning_rate' in ParamAttr_config:
        # NOTE: only support a single value now
        learning_rate_value = ParamAttr_config.get('learning_rate')
        if not isinstance(learning_rate_value, (int, float)):
            raise ValueError(
                f"learning_rate_value must be float or int, but got {type(learning_rate_value)}"
            )
        ParamAttr_dict['learning_rate'] = learning_rate_value
    if 'regularizer' in ParamAttr_config:
        # L1Decay or L2Decay
        ParamAttr_dict['regularizer'] = _build_from_named_cfg(
            paddle.regularizer,
            ParamAttr_config.get('regularizer'), 'regularizer_cfg')
    return paddle.ParamAttr(**ParamAttr_dict)