module.py 8.5 KB
Newer Older
S
Steffy-zxf 已提交
1
# coding:utf-8
W
wuzewu 已提交
2
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
W
wuzewu 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wuzewu 已提交
16
import ast
import functools
import importlib
import inspect
import os
import sys
from typing import Callable, Generic, List, Optional

from easydict import EasyDict

from paddlehub.utils import log, utils
from paddlehub.compat.module.module_v1 import ModuleV1
W
wuzewu 已提交
27

W
wuzewu 已提交
28

W
wuzewu 已提交
29
class InvalidHubModule(Exception):
    '''
    Error raised when a directory cannot be loaded as a HubModule.

    Args:
        directory(str): The path that failed to load; it is kept on the exception as ``directory``.
    '''

    def __init__(self, directory: str):
        # Remember the offending path so that __str__ can report it.
        self.directory = directory

    def __str__(self):
        template = '{} is not a valid HubModule'
        return template.format(self.directory)


# Registry filled by the ``serving`` decorator. Keys are '<module name>.<enclosing class name>' and
# values are the name of the decorated function; RunModule._get_func_name reads it back per class.
_module_serving_func = {}
W
wuzewu 已提交
38
# Registry filled by the ``runnable`` decorator. Keys are '<module name>.<enclosing class name>' and
# values are the name of the decorated function; RunModule._get_func_name reads it back per class.
_module_runnable_func = {}
W
wuzewu 已提交
39 40


41
def runnable(func: Callable) -> Callable:
    '''
    Decorator that registers a method as the runnable function of the class being defined.

    The function name is stored in ``_module_runnable_func`` under the key
    '<func.__module__>.<name of the calling frame>'. When the decorator is applied inside a class body,
    the calling frame is that class body, so the frame name is the class name.

    Args:
        func(Callable): The function to register.

    Returns:
        Callable: A wrapper that forwards every call to ``func`` and keeps its name, docstring and
            signature metadata.
    '''
    # Index 3 of a FrameInfo record is the name of the code object that is applying the decorator.
    mod = func.__module__ + '.' + inspect.stack()[1][3]
    _module_runnable_func[mod] = func.__name__

    # Without functools.wraps the returned function would be called '_wrapper' and lose func's docstring.
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


51
def serving(func: Callable) -> Callable:
    '''
    Decorator that registers a method as the serving function of the class being defined.

    The function name is stored in ``_module_serving_func`` under the key
    '<func.__module__>.<name of the calling frame>'. When the decorator is applied inside a class body,
    the calling frame is that class body, so the frame name is the class name.

    Args:
        func(Callable): The function to register.

    Returns:
        Callable: A wrapper that forwards every call to ``func`` and keeps its name, docstring and
            signature metadata.
    '''
    # Index 3 of a FrameInfo record is the name of the code object that is applying the decorator.
    mod = func.__module__ + '.' + inspect.stack()[1][3]
    _module_serving_func[mod] = func.__name__

    # Without functools.wraps the returned function would be called '_wrapper' and lose func's docstring.
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


W
wuzewu 已提交
61
class Module(object):
    '''
    Entry point used to obtain a HubModule object.

    Calling ``Module(name=...)`` or ``Module(directory=...)`` does not build a ``Module`` instance:
    ``__new__`` forwards to ``init_with_name`` / ``init_with_directory`` and returns the object they
    create. Classes derived from ``Module`` are instantiated in the regular way.

    Args:
        name(str): Name of the module to look up (and install when needed).
        directory(str): Directory of a module to load; only used when ``name`` is empty.
        version(str): Requested version, only used together with ``name``.
        **kwargs: Forwarded to the initialization of the loaded module.
    '''

    def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs):
        if cls.__name__ == 'Module':
            # This branch comes from hub.Module(name='xxx') or hub.Module(directory='xxx')
            if name:
                return cls.init_with_name(name=name, version=version, **kwargs)
            if directory:
                return cls.init_with_directory(directory=directory, **kwargs)
            # Previously this case fell through to an unbound local variable and surfaced as a
            # confusing UnboundLocalError.
            raise ValueError('Either the `name` or the `directory` of the module must be specified')

        return object.__new__(cls)

    @classmethod
    def load(cls, directory: str) -> Generic:
        '''
        Load the module class stored in ``directory``.

        If ``directory`` contains a ``module_desc.pb`` file, loading is delegated to ``ModuleV1.load``.
        Otherwise ``<basename>.module`` is imported from the parent directory and the first class that
        carries the ``_hook_by_hub`` mark and derives from ``RunModule`` is returned, with its
        ``directory`` attribute set to ``directory``.

        Args:
            directory(str): Directory of the module; one trailing path separator is ignored.

        Raises:
            InvalidHubModule: If the imported python module holds no hooked ``RunModule`` subclass.
        '''
        if directory.endswith(os.sep):
            directory = directory[:-1]

        # If module description file existed, try to load as ModuleV1
        desc_file = os.path.join(directory, 'module_desc.pb')
        if os.path.exists(desc_file):
            return ModuleV1.load(directory)

        dirname, basename = os.path.split(directory)
        py_module = utils.load_py_module(dirname, '{}.module'.format(basename))

        for _, candidate in inspect.getmembers(py_module, inspect.isclass):
            if hasattr(candidate, '_hook_by_hub') and issubclass(candidate, RunModule):
                candidate.directory = directory
                return candidate

        raise InvalidHubModule(directory)

    @staticmethod
    def _is_moduleinfo_decorator(decorator: ast.AST) -> bool:
        '''Tell whether an AST decorator node is a call to ``moduleinfo`` (bare or attribute access).'''
        # Decorators used without parentheses are not Call nodes and have no `func` attribute.
        if not isinstance(decorator, ast.Call):
            return False
        func = decorator.func
        if isinstance(func, ast.Name):
            return func.id == 'moduleinfo'
        if isinstance(func, ast.Attribute):
            return func.attr == 'moduleinfo'
        return False

    @classmethod
    def load_module_info(cls, directory: str) -> EasyDict:
        '''
        Read the module information of the module in ``directory`` without importing it.

        If ``directory`` contains a ``module_desc.pb`` file, the work is delegated to
        ``ModuleV1.load_module_info``. Otherwise ``module.py`` is parsed with ``ast`` and the keyword
        arguments of the first ``moduleinfo(...)`` class decorator are returned. Keyword values that
        are not python literals (for example a class passed as ``meta``) are skipped, because they
        cannot be evaluated without importing the module.

        Args:
            directory(str): Directory of the module.

        Returns:
            EasyDict: The literal keyword arguments passed to ``moduleinfo``.

        Raises:
            InvalidHubModule: If no class in ``module.py`` is decorated with ``moduleinfo(...)``.
        '''
        # If is ModuleV1
        desc_file = os.path.join(directory, 'module_desc.pb')
        if os.path.exists(desc_file):
            return ModuleV1.load_module_info(directory)

        # If is ModuleV2
        module_file = os.path.join(directory, 'module.py')
        # Read raw bytes so that the parser applies the source file's own encoding declaration
        # instead of the locale's default text encoding.
        with open(module_file, 'rb') as file:
            ast_module = ast.parse(file.read())

        for _body in ast_module.body:
            if not isinstance(_body, ast.ClassDef):
                continue

            for _decorator in _body.decorator_list:
                if not cls._is_moduleinfo_decorator(_decorator):
                    continue

                info = {}
                for key in _decorator.keywords:
                    # `**mapping` arguments have no name to store the value under.
                    if key.arg is None:
                        continue
                    try:
                        info[key.arg] = ast.literal_eval(key.value)
                    except (ValueError, TypeError):
                        # Not a literal, nothing can be recorded for it here.
                        continue
                return EasyDict(info)

        raise InvalidHubModule(directory)

    @staticmethod
    def _create_user_module(user_module_cls, directory: str, init_kwargs: dict):
        '''Instantiate ``user_module_cls``, honouring the legacy ``_initialize`` protocol.'''
        # The HubModule in the old version will use the _initialize method to initialize,
        # this function will be obsolete in a future version
        if hasattr(user_module_cls, '_initialize'):
            log.logger.warning(
                'The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object'
            )
            user_module = user_module_cls(directory=directory)
            user_module._initialize(**init_kwargs)
            return user_module
        return user_module_cls(directory=directory, **init_kwargs)

    @classmethod
    def init_with_name(cls, name: str, version: str = None, **kwargs):
        '''
        Build a module object from a module name.

        The module is searched with ``LocalModuleManager``; when the search finds nothing, or when
        ``version.match(version)`` of the found module is falsy, the manager is asked to install it.

        Args:
            name(str): Name of the module.
            version(str): Requested version, passed to the version match and to the installation.
            **kwargs: Forwarded to the initialization of the module.
        '''
        # Imported here rather than at module level, as in the original code (presumably to avoid a
        # circular import - TODO confirm).
        from paddlehub.module.manager import LocalModuleManager
        manager = LocalModuleManager()
        user_module_cls = manager.search(name)
        if not user_module_cls or not user_module_cls.version.match(version):
            user_module_cls = manager.install(name, version)

        directory = manager._get_normalized_path(user_module_cls.name)
        return cls._create_user_module(user_module_cls, directory, kwargs)

    @classmethod
    def init_with_directory(cls, directory: str, **kwargs):
        '''
        Build a module object from a module directory.

        Args:
            directory(str): Directory of the module, loaded through ``load``.
            **kwargs: Forwarded to the initialization of the module.
        '''
        user_module_cls = cls.load(directory)
        return cls._create_user_module(user_module_cls, directory, kwargs)

    @classmethod
    def get_py_requirements(cls):
        '''
        Return the entries of ``requirements.txt`` found in ``cls.directory``.

        Returns:
            list: One stripped entry per non-blank line, or an empty list if the file does not exist.
        '''
        req_file = os.path.join(cls.directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []

        with open(req_file, 'r') as file:
            # Blank lines (including the one produced by a trailing newline) are not requirements.
            return [line.strip() for line in file if line.strip()]

W
wuzewu 已提交
181

W
wuzewu 已提交
182
class RunModule(object):
    '''
    Base class that resolves the runnable and serving functions of a module.

    During initialization the registries filled by the ``runnable`` and ``serving`` decorators are
    searched for the instance's class (and its base classes); the results are stored in
    ``_run_func`` (a bound method or None) and ``_serving_func_name`` (a name or None).
    '''

    def __init__(self, *args, **kwargs):
        # Avoid module being initialized multiple times
        if '_is_initialize' in self.__dict__ and self._is_initialize:
            return

        super(RunModule, self).__init__()
        _run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
        self._is_initialize = True

    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
        '''
        Find the function name registered for ``current_cls`` or one of its base classes.

        Args:
            current_cls: The class to start the search from.
            module_func_dict(dict): Registry mapping '<module>.<class name>' to a function name.

        Returns:
            Optional[str]: The registered name, or None if neither the class nor any base has one.
        '''
        mod = current_cls.__module__ + '.' + current_cls.__name__
        if mod in module_func_dict:
            return module_func_dict[mod]

        # Every base has to be searched: returning the result of the first base unconditionally
        # would hide registrations made on the other bases of a multiply-inherited class.
        for base_class in current_cls.__bases__:
            found = self._get_func_name(base_class, module_func_dict)
            if found is not None:
                return found
        return None

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''
        Return the entries of the ``requirements.txt`` located next to the file defining the class.

        Returns:
            List[str]: One stripped entry per non-blank line, or an empty list if the file does
                not exist.
        '''
        py_module = sys.modules[cls.__module__]
        directory = os.path.dirname(py_module.__file__)
        req_file = os.path.join(directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []
        with open(req_file, 'r') as file:
            # Return a list in both branches; blank lines are not requirements.
            return [line.strip() for line in file if line.strip()]

    @property
    def is_runnable(self) -> bool:
        '''Whether a runnable function was found for this module.'''
        return self._run_func is not None
W
wuzewu 已提交
223 224


W
wuzewu 已提交
225
# Alias of the builtin ``type``. ``moduleinfo`` has a parameter named ``type`` that shadows the builtin
# inside its body, so the class constructor is reached through this name there.
sys_type = type
W
wuzewu 已提交
226

W
wuzewu 已提交
227

W
wuzewu 已提交
228 229 230 231 232 233
def moduleinfo(name: str,
               version: str,
               author: str = None,
               author_email: str = None,
               summary: str = None,
               type: str = None,
               meta=None) -> Callable:
    '''
    Class decorator that attaches module information to a class and marks it as a HubModule.

    If the decorated class does not derive from ``meta`` (``RunModule`` when ``meta`` is not given),
    a new class with the same name and attributes is created that additionally inherits from it.

    Args:
        name(str): Module name, stored as ``name``.
        version(str): Module version, stored as ``utils.Version(version)``.
        author(str): Stored as ``author``.
        author_email(str): Stored as ``author_email``.
        summary(str): Stored as ``summary``.
        type(str): Stored as ``type``.
        meta: Base class the decorated class must inherit from, default is ``RunModule``.

    Returns:
        Callable: The decorator; it returns the (possibly re-created) class with the information
            above and ``_hook_by_hub = True`` set on it.
    '''

    def _wrapper(cls: Generic) -> Generic:
        wrap_cls = cls
        _meta = RunModule if not meta else meta
        if not issubclass(cls, _meta):
            # Bases that _meta itself inherits from (e.g. object) are dropped: listing them before
            # _meta would make a consistent method resolution order impossible.
            _bases = tuple(_b for _b in cls.__bases__ if not issubclass(_meta, _b)) + (_meta, )

            attr_dict = dict(cls.__dict__)
            # '__dict__' and '__weakref__' are descriptors bound to the original class. Copying them
            # into a class that does not derive from it makes `instance.__dict__` raise TypeError,
            # and the new class gets its own descriptors where it needs them anyway.
            attr_dict.pop('__dict__', None)
            attr_dict.pop('__weakref__', None)
            wrap_cls = sys_type(cls.__name__, _bases, attr_dict)

        wrap_cls.name = name
        wrap_cls.version = utils.Version(version)
        wrap_cls.author = author
        wrap_cls.author_email = author_email
        wrap_cls.summary = summary
        wrap_cls.type = type
        wrap_cls._hook_by_hub = True
        return wrap_cls

    return _wrapper