module.py 8.8 KB
Newer Older
S
Steffy-zxf 已提交
1
# coding:utf-8
W
wuzewu 已提交
2
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
W
wuzewu 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wuzewu 已提交
16
import ast
W
wuzewu 已提交
17 18
import inspect
import importlib
W
wuzewu 已提交
19 20
import os
import sys
W
wuzewu 已提交
21
from typing import Callable, Generic, List, Optional
W
wuzewu 已提交
22

W
wuzewu 已提交
23 24
from easydict import EasyDict

W
wuzewu 已提交
25
from paddlehub.utils import log, utils
W
wuzewu 已提交
26
from paddlehub.compat.module.module_v1 import ModuleV1
W
wuzewu 已提交
27

W
wuzewu 已提交
28

W
wuzewu 已提交
29
class InvalidHubModule(Exception):
    '''
    Raised when a directory does not contain a loadable HubModule.

    Args:
        directory: The offending module directory, kept on the ``directory`` attribute.
    '''

    def __init__(self, directory: str):
        # Forward to Exception so ``args`` is populated. Without it ``args`` is empty and
        # unpickling (which calls ``InvalidHubModule(*args)``) fails with a TypeError,
        # e.g. when the error is sent back from a multiprocessing worker.
        super(InvalidHubModule, self).__init__(directory)
        self.directory = directory

    def __str__(self):
        return '{} is not a valid HubModule'.format(self.directory)


# Registries filled at decoration time by ``@serving`` and ``@runnable``.
# Each maps '<function module>.<name of the scope the decorator ran in>' (the class
# name when used on a method inside a class body) to the decorated function's name.
_module_serving_func, _module_runnable_func = {}, {}
W
wuzewu 已提交
39 40


41
def runnable(func: Callable) -> Callable:
    '''
    Mark ``func`` as the command-line entry point of the class it is defined in.

    The name of ``func`` is recorded in ``_module_runnable_func`` under the key
    ``'<func.__module__>.<caller scope name>'``. When applied to a method inside a
    class body, the caller scope is the class body, so the key has the same
    ``'<module>.<class name>'`` shape that ``RunModule._get_func_name`` looks up.

    Args:
        func: The function being decorated.

    Returns:
        A wrapper that forwards all arguments to ``func`` and keeps its metadata
        (``__name__``, ``__doc__``, ``__wrapped__``).
    '''
    import functools

    # Only the direct caller's code name is needed. inspect.stack() would build FrameInfo
    # records (including source context read from disk) for every frame on the stack.
    # The frame object is not stored in a local, which avoids a frame reference cycle.
    caller_name = inspect.currentframe().f_back.f_code.co_name
    mod = func.__module__ + '.' + caller_name
    _module_runnable_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


51
def serving(func: Callable) -> Callable:
    '''
    Mark ``func`` as the serving (prediction service) entry of the class it is defined in.

    The name of ``func`` is recorded in ``_module_serving_func`` under the key
    ``'<func.__module__>.<caller scope name>'``. When applied to a method inside a
    class body, the caller scope is the class body, so the key has the same
    ``'<module>.<class name>'`` shape that ``RunModule._get_func_name`` looks up.

    Args:
        func: The function being decorated.

    Returns:
        A wrapper that forwards all arguments to ``func`` and keeps its metadata
        (``__name__``, ``__doc__``, ``__wrapped__``).
    '''
    import functools

    # Only the direct caller's code name is needed. inspect.stack() would build FrameInfo
    # records (including source context read from disk) for every frame on the stack.
    # The frame object is not stored in a local, which avoids a frame reference cycle.
    caller_name = inspect.currentframe().f_back.f_code.co_name
    mod = func.__module__ + '.' + caller_name
    _module_serving_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


W
wuzewu 已提交
61
class Module(object):
    '''
    Factory entry point for PaddleHub modules.

    ``Module(name='xxx')`` or ``Module(directory='xxx')`` does not return a ``Module``
    instance. ``__new__`` resolves the concrete module class (``ModuleV1`` or a
    ``@moduleinfo``-decorated ``RunModule`` subclass), instantiates it and returns that
    object instead.
    '''

    def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs):
        '''
        Build a module object.

        When called as ``Module(...)`` itself, ``name`` takes precedence over ``directory``.
        Subclasses (any class whose ``__name__`` is not ``'Module'``) get a plain new instance.

        Raises:
            ValueError: if ``Module(...)`` is called with neither ``name`` nor ``directory``.
                Previously this path failed with an UnboundLocalError.
        '''
        if cls.__name__ != 'Module':
            return object.__new__(cls)

        # This branch comes from hub.Module(name='xxx') or hub.Module(directory='xxx')
        if name:
            return cls.init_with_name(name=name, version=version, **kwargs)
        if directory:
            return cls.init_with_directory(directory=directory, **kwargs)
        raise ValueError('Either `name` or `directory` must be specified to load a HubModule')

    @classmethod
    def load(cls, directory: str) -> Generic:
        '''
        Load (but do not instantiate) the module stored in ``directory``.

        Args:
            directory: Path of the module directory. A single trailing ``os.sep`` is ignored.

        Returns:
            The result of ``ModuleV1.load(directory)`` if ``module_desc.pb`` exists. Otherwise
            the first class (by name) in ``<directory>/module.py`` that carries
            ``_hook_by_hub`` and subclasses ``RunModule``, with ``directory`` set on it.

        Raises:
            InvalidHubModule: if ``module.py`` defines no such class.
        '''
        if directory.endswith(os.sep):
            directory = directory[:-1]

        # If module description file existed, try to load as ModuleV1
        desc_file = os.path.join(directory, 'module_desc.pb')
        if os.path.exists(desc_file):
            return ModuleV1.load(directory)

        # The module directory is imported as a package from its parent directory.
        dirname, basename = os.path.split(directory)
        py_module = utils.load_py_module(dirname, '{}.module'.format(basename))

        for _, member in inspect.getmembers(py_module, inspect.isclass):
            if hasattr(member, '_hook_by_hub') and issubclass(member, RunModule):
                user_module_cls = member
                break
        else:
            raise InvalidHubModule(directory)

        user_module_cls.directory = directory
        return user_module_cls

    @classmethod
    def load_module_info(cls, directory: str) -> EasyDict:
        '''
        Read a module's metadata without importing its code.

        For a ModuleV1 directory (contains ``module_desc.pb``) this delegates to
        ``ModuleV1.load_module_info``. Otherwise ``<directory>/module.py`` is parsed with
        ``ast``. The function returns the literal keyword arguments of the first
        ``moduleinfo(...)`` (or ``<pkg>.moduleinfo(...)``) decorator found on a top-level class.
        Keyword values that are not literals (e.g. ``meta=hub.SomeBase``) are skipped.

        Raises:
            InvalidHubModule: if no top-level class carries a ``moduleinfo`` decorator.
        '''
        # If is ModuleV1
        desc_file = os.path.join(directory, 'module_desc.pb')
        if os.path.exists(desc_file):
            return ModuleV1.load_module_info(directory)

        # If is ModuleV2. Parse the raw bytes so the file's own PEP 263 coding declaration
        # (UTF-8 by default) is honoured rather than the platform's locale encoding.
        module_file = os.path.join(directory, 'module.py')
        with open(module_file, 'rb') as file:
            ast_module = ast.parse(file.read())

        for node in ast_module.body:
            if not isinstance(node, ast.ClassDef):
                continue

            for decorator in node.decorator_list:
                if not cls._is_moduleinfo_call(decorator):
                    continue

                info = {}
                for keyword in decorator.keywords:
                    # `**kwargs` entries have no name and cannot be resolved statically.
                    if keyword.arg is None:
                        continue
                    try:
                        info[keyword.arg] = ast.literal_eval(keyword.value)
                    except (ValueError, TypeError):
                        # Not a literal; its value is only known after importing the module.
                        continue
                return EasyDict(info)

        raise InvalidHubModule(directory)

    @staticmethod
    def _is_moduleinfo_call(node: ast.expr) -> bool:
        '''Return True if ``node`` is a ``moduleinfo(...)`` or ``<pkg>.moduleinfo(...)`` call.'''
        if not isinstance(node, ast.Call):
            return False
        func = node.func
        if isinstance(func, ast.Name):
            return func.id == 'moduleinfo'
        if isinstance(func, ast.Attribute):
            return func.attr == 'moduleinfo'
        return False

    @classmethod
    def init_with_name(cls, name: str, version: str = None, **kwargs):
        '''
        Instantiate the locally managed module ``name``.

        The module is (re)installed through ``LocalModuleManager.install`` when
        ``search`` finds nothing or the found class's ``version.match(version)`` is false.
        ``kwargs`` are forwarded to the module's initializer.
        '''
        from paddlehub.module.manager import LocalModuleManager
        manager = LocalModuleManager()
        user_module_cls = manager.search(name)
        if not user_module_cls or not user_module_cls.version.match(version):
            user_module_cls = manager.install(name=name, version=version)

        directory = manager._get_normalized_path(user_module_cls.name)
        return cls._instantiate(user_module_cls, directory, **kwargs)

    @classmethod
    def init_with_directory(cls, directory: str, **kwargs):
        '''
        Instantiate the module stored in ``directory`` (see ``load``).

        ``kwargs`` are forwarded to the module's initializer.
        '''
        user_module_cls = cls.load(directory)
        return cls._instantiate(user_module_cls, directory, **kwargs)

    @staticmethod
    def _instantiate(user_module_cls, directory: str, **kwargs):
        '''Create an instance of ``user_module_cls`` located at ``directory``, forwarding ``kwargs``.'''
        # The HubModule in the old version will use the _initialize method to initialize,
        # this function will be obsolete in a future version
        if hasattr(user_module_cls, '_initialize'):
            log.logger.warning(
                'The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object'
            )
            user_module = user_module_cls(directory=directory)
            user_module._initialize(**kwargs)
            return user_module

        if user_module_cls == ModuleV1:
            return user_module_cls(directory=directory, **kwargs)

        user_module_cls.directory = directory
        return user_module_cls(**kwargs)

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''
        Return the entries of ``<cls.directory>/requirements.txt``.

        Each line is stripped. Blank lines and full-line ``#`` comments are skipped, since
        they are not installable requirements. Returns an empty list when the file does
        not exist.
        '''
        req_file = os.path.join(cls.directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []

        with open(req_file, 'r') as file:
            lines = (line.strip() for line in file)
            return [line for line in lines if line and not line.startswith('#')]

W
wuzewu 已提交
191

W
wuzewu 已提交
192
class RunModule(object):
    '''
    Base class of PaddleHub 2.0 modules.

    ``@moduleinfo`` makes decorated classes derive from this class. On construction it
    resolves which method names were registered through ``@runnable`` and ``@serving``
    for the instance's class or any of its ancestors.
    '''

    def __init__(self, *args, **kwargs):
        # Avoid module being initialized multiple times
        if '_is_initialize' in self.__dict__ and self._is_initialize:
            return

        super(RunModule, self).__init__()
        _run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
        self._is_initialize = True

    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
        '''
        Return the function name registered in ``module_func_dict`` for ``current_cls``.

        Keys have the form ``'<module>.<class name>'``. Classes are searched in method
        resolution order, which matches how ``getattr`` later resolves the name. A
        registration on any base class is therefore found. The old recursion only ever
        followed the first base in ``__bases__``. Returns None when nothing is registered.
        '''
        for klass in current_cls.__mro__:
            key = klass.__module__ + '.' + klass.__name__
            if key in module_func_dict:
                return module_func_dict[key]
        return None

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''
        Return the entries of the ``requirements.txt`` next to the file defining ``cls``.

        Each line is stripped. Blank lines and full-line ``#`` comments are skipped.
        Returns an empty list when the file does not exist. Previously the whole file
        content was returned as one string, contradicting the ``List[str]`` annotation.
        '''
        py_module = sys.modules[cls.__module__]
        directory = os.path.dirname(py_module.__file__)
        req_file = os.path.join(directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []

        with open(req_file, 'r') as file:
            lines = (line.strip() for line in file)
            return [line for line in lines if line and not line.startswith('#')]

    @property
    def is_runnable(self) -> bool:
        '''True if a ``@runnable`` method was resolved for this module.'''
        return self._run_func is not None
W
wuzewu 已提交
233 234


W
wuzewu 已提交
235
# ``moduleinfo`` below has a parameter named ``type``, which shadows the builtin
# inside its scope. This module-level alias keeps the builtin ``type`` reachable
# there so it can create the wrapped class with ``sys_type(name, bases, namespace)``.
sys_type = type
W
wuzewu 已提交
236

W
wuzewu 已提交
237

W
wuzewu 已提交
238 239 240 241 242 243
def moduleinfo(name: str,
               version: str,
               author: str = None,
               author_email: str = None,
               summary: str = None,
               type: str = None,
               meta=None) -> Callable:
    '''
    Class decorator that declares a PaddleHub module and attaches its metadata.

    Args:
        name: Module name, stored on ``cls.name``.
        version: Version string, stored as ``utils.Version(version)`` on ``cls.version``.
        author: Stored on ``cls.author``.
        author_email: Stored on ``cls.author_email``.
        summary: Stored on ``cls.summary``.
        type: Module type, stored on ``cls.type``.
        meta: Base class the module must derive from; ``RunModule`` when falsy.

    Returns:
        A decorator. If the decorated class already subclasses ``meta``, the decorator
        annotates and returns that class. Otherwise it returns a *new* class with the same
        name and namespace. Its bases are the original bases, minus any that ``meta``
        already derives from, followed by ``meta``. The new class also gets
        ``_hook_by_hub = True``.
    '''

    def _wrapper(cls: Generic) -> Generic:
        wrap_cls = cls
        _meta = RunModule if not meta else meta
        if not issubclass(cls, _meta):
            _bases = tuple(_b for _b in cls.__bases__ if not issubclass(_meta, _b)) + (_meta, )

            namespace = dict(cls.__dict__)
            # The '__dict__' / '__weakref__' getset descriptors (and __slots__ member
            # descriptors) belong to the original class. If copied into the new class they
            # shadow its own and reject its instances ("descriptor '__dict__' for 'X'
            # objects doesn't apply to a 'X' object"), which would break e.g. the
            # self.__dict__ check in RunModule.__init__. Drop them and let type() rebuild them.
            namespace.pop('__dict__', None)
            namespace.pop('__weakref__', None)
            slots = namespace.get('__slots__')
            if slots is not None:
                for slot in ([slots] if isinstance(slots, str) else slots):
                    namespace.pop(slot, None)
            # NOTE(review): methods using zero-argument super() still close over the
            # original class and will fail when called on the rebuilt class.
            wrap_cls = sys_type(cls.__name__, _bases, namespace)

        wrap_cls.name = name
        wrap_cls.version = utils.Version(version)
        wrap_cls.author = author
        wrap_cls.author_email = author_email
        wrap_cls.summary = summary
        wrap_cls.type = type
        wrap_cls._hook_by_hub = True
        return wrap_cls

    return _wrapper