module.py 5.7 KB
Newer Older
S
Steffy-zxf 已提交
1
# coding:utf-8
W
wuzewu 已提交
2
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
W
wuzewu 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wuzewu 已提交
16 17
import functools
import importlib
import inspect
import os
import sys
from typing import Callable, List, Optional, Generic

from paddlehub.utils import utils

W
wuzewu 已提交
24

W
wuzewu 已提交
25
class InvalidHubModule(Exception):
    '''Error signalling that a directory is not a valid HubModule.

    Args:
        directory(str): The rejected path; kept on the instance as ``directory``.
    '''

    def __init__(self, directory: str):
        self.directory = directory

    def __str__(self):
        template = '{path} is not a valid HubModule'
        return template.format(path=self.directory)


# Registries filled in by the `serving` / `runnable` decorators below. Keys are
# '<func.__module__>.<name of the frame that applied the decorator>' (the class name when
# the decorator is used inside a class body); values are the decorated function's name.
_module_serving_func = {}
_module_runnable_func = {}


def runnable(func: Callable) -> Callable:
    '''Register ``func`` as the runnable method of the class whose body is being executed.

    Records ``'<func.__module__>.<caller name>' -> func.__name__`` in ``_module_runnable_func``.
    Inside a class body the caller name is the class name, which matches the
    ``'<class module>.<class name>'`` key that ``RunModule`` looks up.

    Returns:
        A wrapper that forwards all arguments to ``func`` and carries ``func``'s name,
        docstring and other metadata.
    '''
    # context=0 skips reading source lines for every frame on the stack; [1][3] is the
    # function name of the caller frame (the class body when used as a method decorator).
    mod = func.__module__ + '.' + inspect.stack(0)[1][3]
    _module_runnable_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


def serving(func: Callable) -> Callable:
    '''Register ``func`` as the serving method of the class whose body is being executed.

    Records ``'<func.__module__>.<caller name>' -> func.__name__`` in ``_module_serving_func``.
    Inside a class body the caller name is the class name.

    Returns:
        A wrapper that forwards all arguments to ``func`` and carries ``func``'s name,
        docstring and other metadata.
    '''
    # Same caller-name lookup as `runnable`; context=0 avoids reading source lines.
    mod = func.__module__ + '.' + inspect.stack(0)[1][3]
    _module_serving_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


W
wuzewu 已提交
57 58 59
class Module(object):
    '''Entry point for loading a PaddleHub module by name or from a directory.

    ``Module(name=...)`` resolves the module class through ``LocalModuleManager``
    (installing it when it is missing locally or its version does not match), while
    ``Module(directory=...)`` imports ``module.py`` from that directory. In both cases
    ``__new__`` returns an instance of the resolved module class, built with ``**kwargs``.
    '''

    def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs):
        if cls.__name__ != 'Module':
            raise RuntimeError('Module.__new__ only supports instantiating `Module` itself, got {}'.format(
                cls.__name__))

        if name:
            module = cls.init_with_name(name=name, version=version, **kwargs)
        elif directory:
            module = cls.init_with_directory(directory=directory, **kwargs)
        else:
            # Fail clearly instead of reaching the assignment below with `module` unbound.
            raise RuntimeError('Either `name` or `directory` must be given to load a Module')

        if directory is not None:
            # Only record an explicitly passed path. Classes returned by `load` already carry a
            # `directory` class attribute, and assigning None here would shadow it on the instance.
            # NOTE(review): assumes classes returned by LocalModuleManager also come from `load` — confirm.
            module.directory = directory
        return module

    @classmethod
    def load(cls, directory: str) -> Generic:
        '''Import ``<directory>/module.py`` and return the hub module class declared in it.

        Args:
            directory(str): Path of the module directory. A single trailing separator is ignored.

        Returns:
            The first class found by ``inspect.getmembers`` (which sorts by name) that has a
            ``_hook_by_hub`` attribute and subclasses ``RunModule``. Its ``directory`` class
            attribute is set to ``directory``.

        Raises:
            InvalidHubModule: if ``module.py`` contains no such class.
        '''
        if directory.endswith(os.sep):
            directory = directory[:-1]

        dirname, basename = os.path.split(directory)

        # Make the parent directory importable only for the duration of the import, and undo it
        # even when the import fails.
        sys.path.insert(0, dirname)
        try:
            py_module = importlib.import_module('{}.module'.format(basename))
        finally:
            sys.path.remove(dirname)

        for _, member in inspect.getmembers(py_module, inspect.isclass):
            if hasattr(member, '_hook_by_hub') and issubclass(member, RunModule):
                user_module_cls = member
                break
        else:
            raise InvalidHubModule(directory)

        user_module_cls.directory = directory
        return user_module_cls

    @classmethod
    def init_with_name(cls, name: str, version: str = None, **kwargs):
        '''Instantiate the module called ``name``, installing it when needed.

        The class found by ``LocalModuleManager.search`` is used unless it is missing or its
        ``version`` does not match the requested ``version``. In that case
        ``LocalModuleManager.install`` provides the class. ``kwargs`` go to the class constructor.
        '''
        # Imported here rather than at file level, presumably to avoid a circular import — confirm.
        from paddlehub.module.manager import LocalModuleManager
        manager = LocalModuleManager()
        user_module_cls = manager.search(name)
        if not user_module_cls or not user_module_cls.version.match(version):
            user_module_cls = manager.install(name, version)

        return user_module_cls(**kwargs)

    @classmethod
    def init_with_directory(cls, directory: str, **kwargs):
        '''Load the module class from ``directory`` (see ``load``) and instantiate it with ``kwargs``.'''
        user_module_cls = cls.load(directory)
        return user_module_cls(**kwargs)
W
wuzewu 已提交
107

W
wuzewu 已提交
108

W
wuzewu 已提交
109 110 111 112
class RunModule(object):
    '''Base class of user-defined PaddleHub modules.

    On first construction it looks up which methods were registered for this class (or one of
    its base classes) in the ``_module_runnable_func`` and ``_module_serving_func`` registries.
    '''

    def __init__(self, *args, **kwargs):
        # Avoid module being initialized multiple times
        if self.__dict__.get('_is_initialize', False):
            return

        super(RunModule, self).__init__()
        _run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
        self._is_initialize = True

    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
        '''Return the function name registered for ``current_cls`` in ``module_func_dict``.

        Keys are ``'<class module>.<class name>'``. If ``current_cls`` itself is not registered,
        its base classes are searched depth-first, left to right, and the first hit is returned.
        Returns None when nothing in the hierarchy is registered.
        '''
        mod = current_cls.__module__ + '.' + current_cls.__name__
        if mod in module_func_dict:
            return module_func_dict[mod]

        # Keep going through the remaining bases when one branch finds nothing, so a
        # registration on a later base (multiple inheritance) is still found.
        for base_class in current_cls.__bases__:
            func_name = self._get_func_name(base_class, module_func_dict)
            if func_name is not None:
                return func_name
        return None

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''Return the python package requirements declared in the module's ``requirements.txt``.

        The file is looked up in the directory of the python file that defines ``cls``. Each
        non-empty line that is not a ``#`` comment becomes one entry, with surrounding whitespace
        stripped. Returns an empty list when the file does not exist.
        '''
        py_module = sys.modules[cls.__module__]
        directory = os.path.dirname(py_module.__file__)
        req_file = os.path.join(directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []
        with open(req_file, 'r', encoding='utf-8') as file:
            lines = (line.strip() for line in file)
            return [line for line in lines if line and not line.startswith('#')]

    @property
    def is_runnable(self) -> bool:
        '''Whether a runnable method was found for this module's class at construction time.'''
        return self._run_func is not None
W
wuzewu 已提交
145 146


W
wuzewu 已提交
147
# `moduleinfo` has a parameter named `type` that shadows the builtin inside it; this module-level
# alias keeps the builtin reachable for building the wrapped class.
sys_type = type


def moduleinfo(name: str,
               version: str,
               author: str = None,
               author_email: str = None,
               summary: str = None,
               type: str = None,
               meta=None) -> Callable:
    '''Class decorator that attaches PaddleHub module metadata to a class.

    Args:
        name(str): Module name.
        version(str): Module version, stored as ``utils.Version(version)``.
        author(str): Author name.
        author_email(str): Author e-mail.
        summary(str): Short description of the module.
        type(str): Module type, stored as the class attribute ``type``.
        meta: Base class the decorated class must derive from. A falsy value means ``RunModule``.

    Returns:
        A decorator. If the decorated class is not already a subclass of ``meta``, the decorator
        returns a new class with the same name and body. Its bases are the original bases (minus
        those ``meta`` already inherits from) followed by ``meta``. The metadata attributes and
        ``_hook_by_hub = True`` are set on the returned class.
    '''

    def _wrapper(cls: Generic) -> Generic:
        wrap_cls = cls
        _meta = RunModule if not meta else meta
        if not issubclass(cls, _meta):
            # Drop bases `_meta` already inherits from (e.g. `object`) so the MRO stays consistent.
            _bases = tuple(_b for _b in cls.__bases__ if not issubclass(_meta, _b)) + (_meta, )
            namespace = dict(cls.__dict__)
            # The `__dict__` / `__weakref__` descriptors copied from `cls` only accept instances of
            # `cls`. Left in the namespace, they would make `instance.__dict__` raise TypeError on
            # the new class (RunModule.__init__ reads it), so let `type()` provide its own.
            namespace.pop('__dict__', None)
            namespace.pop('__weakref__', None)
            wrap_cls = sys_type(cls.__name__, _bases, namespace)

        wrap_cls.name = name
        wrap_cls.version = utils.Version(version)
        wrap_cls.author = author
        wrap_cls.author_email = author_email
        wrap_cls.summary = summary
        wrap_cls.type = type
        wrap_cls._hook_by_hub = True
        return wrap_cls

    return _wrapper