dynamic_import.py 2.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
H
Hui Zhang 已提交
14
# Modified from espnet(https://github.com/espnet/espnet)
15
import importlib
16
import inspect
17 18
from typing import Any
from typing import Dict
19
from typing import List
20 21
from typing import Text

22 23
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.tensor_utils import has_tensor
24 25 26

# Module-level logger obtained from the project's ``Log`` wrapper.
logger = Log(__name__).getlog()

# Names exported by ``from ... import *``; the filter_* helpers are not listed.
__all__ = [
    "dynamic_import",
    "instance_class",
]
28 29 30 31 32 33


def dynamic_import(import_path, alias=None):
    """Dynamically import a module and return one of its attributes (usually a class).

    :param str import_path: either ``'module_name:class_name'``,
        e.g. ``'paddlespeech.s2t.models.u2:U2Model'``, or a key of ``alias``
    :param dict alias: optional shortcuts mapping a short name to a full
        ``'module_name:class_name'`` path; defaults to no shortcuts
    :return: the object named ``class_name`` in module ``module_name``
    :raises ValueError: if ``import_path`` is neither a key of ``alias`` nor a
        well-formed ``'module_name:class_name'`` string
    """
    # ``None`` rather than ``dict()`` as the default: a mutable default is
    # created once at definition time and shared by every call.
    if alias is None:
        alias = {}
    if import_path not in alias and ":" not in import_path:
        raise ValueError(
            "import_path should be one of {} or "
            'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
            "{}".format(set(alias), import_path))
    # Only paths without an explicit "module:attr" form are resolved via alias.
    if ":" not in import_path:
        import_path = alias[import_path]

    module_name, sep, objname = import_path.partition(":")
    # Reject "a:b:c", ":Cls" and alias values lacking ":" with a readable
    # message instead of an unpacking error or an empty-module-name error.
    if not sep or not module_name or ":" in objname:
        raise ValueError(
            'import_path must have the form "module_name:class_name", '
            "got {!r}".format(import_path))
    module = importlib.import_module(module_name)
    return getattr(module, objname)


51 52 53 54 55 56
def filter_valid_args(args: Dict[Text, Any], valid_keys: List[Text]):
    """Select the entries of ``args`` that are allowed and set.

    :param dict args: candidate keyword arguments
    :param valid_keys: accepted argument names (any container supporting ``in``)
    :return: a new dict holding only the pairs whose key is in ``valid_keys``
        and whose value is not None, in the original order; ``args`` itself
        is left unchanged
    """
    kept = {}
    for name, value in args.items():
        # Same test order as ``key in valid_keys and val is not None``.
        if name not in valid_keys or value is None:
            continue
        kept[name] = value
    return kept


60 61 62 63
def filter_out_tenosr(args: Dict[Text, Any]):
    """Return a copy of ``args`` without the entries ``has_tensor`` flags.

    ``has_tensor`` comes from ``tensor_utils``; presumably it reports whether a
    value is or contains a tensor (TODO confirm). The misspelled function name
    is kept because callers reference it.
    """
    remaining = dict()
    for name, value in args.items():
        if has_tensor(value):
            continue
        remaining[name] = value
    return remaining


H
Hui Zhang 已提交
64
def instance_class(module_class, args: Dict[Text, Any]):
    """Instantiate ``module_class`` with the subset of ``args`` it accepts.

    Keys of ``args`` that are not parameter names in
    ``inspect.signature(module_class)`` are dropped, as are entries whose
    value is None. The surviving arguments, minus values that ``has_tensor``
    flags, are logged at INFO level before the call.

    NOTE(review): a ``**kwargs`` parameter only contributes its own name to the
    accepted keys, so arbitrary extra arguments are still filtered out.

    :param module_class: the class (or other callable) to instantiate
    :param dict args: candidate keyword arguments
    :return: ``module_class(**filtered_args)``
    """
    accepted = inspect.signature(module_class).parameters.keys()
    kwargs = filter_valid_args(args, accepted)
    logger.info("Instance: {} {}.".format(module_class.__name__,
                                          filter_out_tenosr(kwargs)))
    return module_class(**kwargs)