# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import traceback class Registry(object): """ The registry that provides name -> object mapping, to support third-party users' custom modules. To create a registry (inside ppgan): .. code-block:: python BACKBONE_REGISTRY = Registry('BACKBONE') To register an object: .. code-block:: python @BACKBONE_REGISTRY.register() class MyBackbone(): ... Or: .. code-block:: python BACKBONE_REGISTRY.register(MyBackbone) """ def __init__(self, name): """ Args: name (str): the name of this registry """ self._name = name self._obj_map = {} def _do_register(self, name, obj): assert ( name not in self._obj_map ), "An object named '{}' was already registered in '{}' registry!".format( name, self._name) self._obj_map[name] = obj def register(self, obj=None, name=None): """ Register the given object under the the name `obj.__name__`. Can be used as either a decorator or not. See docstring of this class for usage. """ if obj is None: # used as a decorator def deco(func_or_class, name=name): if name is None: name = func_or_class.__name__ self._do_register(name, func_or_class) return func_or_class return deco # used as a function call if name is None: name = obj.__name__ self._do_register(name, obj) def get(self, name): ret = self._obj_map.get(name) if ret is None: raise KeyError( "No object named '{}' found in '{}' registry!".format( name, self._name)) return ret def build_from_config(cfg, registry, default_args=None): """Build a class from config dict. Args: cfg (dict): Config dict. It should at least contain the key "name". registry (ppgan.utils.Registry): The registry to search the name from. default_args (dict, optional): Default initialization arguments. Returns: class: The constructed class. """ if not isinstance(cfg, dict): raise TypeError(f'cfg must be a dict, but got {type(cfg)}') if 'name' not in cfg: if default_args is None or 'name' not in default_args: raise KeyError( '`cfg` or `default_args` must contain the key "name", ' f'but got {cfg}\n{default_args}') if not isinstance(registry, Registry): raise TypeError('registry must be an mmcv.Registry object, ' f'but got {type(registry)}') if not (isinstance(default_args, dict) or default_args is None): raise TypeError('default_args must be a dict or None, ' f'but got {type(default_args)}') args = cfg.copy() if default_args is not None: for name, value in default_args.items(): args.setdefault(name, value) cls_name = args.pop('name') if isinstance(cls_name, str): obj_cls = registry.get(cls_name) elif inspect.isclass(cls_name): obj_cls = obj_cls else: raise TypeError( f'name must be a str or valid name, but got {type(cls_name)}') try: instance = obj_cls(**args) except Exception as e: stack_info = traceback.format_exc() print("Fail to initial class [{}] with error: " "{} and stack:\n{}".format(cls_name, e, str(stack_info))) raise e return instance