module.py 10.5 KB
Newer Older
S
Steffy-zxf 已提交
1
# coding:utf-8
W
wuzewu 已提交
2
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
W
wuzewu 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wuzewu 已提交
16
import ast
W
wuzewu 已提交
17
import builtins
W
wuzewu 已提交
18 19
import inspect
import importlib
W
wuzewu 已提交
20
import os
21
import re
W
wuzewu 已提交
22
import sys
W
wuzewu 已提交
23
from typing import Callable, Generic, List, Optional
W
wuzewu 已提交
24

W
wuzewu 已提交
25 26
from easydict import EasyDict

27
import paddle
W
wuzewu 已提交
28
from paddlehub.utils import parser, log, utils
29
from paddlehub.compat import paddle_utils
W
wuzewu 已提交
30
from paddlehub.compat.module.module_v1 import ModuleV1
W
wuzewu 已提交
31

W
wuzewu 已提交
32

W
wuzewu 已提交
33
class InvalidHubModule(Exception):
    '''
    Raised when a directory cannot be loaded as a HubModule.

    Args:
        directory(str): The directory that failed to load; it is kept on the instance as ``directory``.
    '''

    def __init__(self, directory: str):
        # Forward the argument to Exception so that ``args`` is populated; without it the exception
        # cannot be rebuilt by pickle/copy, which re-invoke the constructor with ``args``.
        super(InvalidHubModule, self).__init__(directory)
        self.directory = directory

    def __str__(self):
        return '{} is not a valid HubModule'.format(self.directory)


# Registry filled by the `serving` decorator: maps '<func.__module__>.<name of the frame applying the
# decorator>' to the decorated function's name. RunModule looks it up with '<cls.__module__>.<cls.__name__>'.
_module_serving_func = {}
W
wuzewu 已提交
42
# Registry filled by the `runnable` decorator: maps '<func.__module__>.<name of the frame applying the
# decorator>' to the decorated function's name. RunModule looks it up with '<cls.__module__>.<cls.__name__>'.
_module_runnable_func = {}
W
wuzewu 已提交
43 44


45
def runnable(func: Callable) -> Callable:
    '''
    Register a function as the runnable entry of the scope that applies this decorator.

    The function name is stored in ``_module_runnable_func`` under the key
    ``'<func.__module__>.<name of the calling frame>'``. When the decorator is applied inside a class
    body, the calling frame is that class body, so the key ends with the class name.

    Args:
        func(Callable): The function to register.

    Returns:
        Callable: A wrapper that forwards every call to ``func`` and keeps its name and docstring.
    '''
    import functools

    # stack()[1] is the frame that invoked the decorator; index 3 is that frame's function name.
    mod = func.__module__ + '.' + inspect.stack()[1][3]
    _module_runnable_func[mod] = func.__name__

    # Preserve __name__/__doc__ of the wrapped function instead of exposing '_wrapper'.
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


55
def serving(func: Callable) -> Callable:
    '''
    Register a function as the serving entry of the scope that applies this decorator.

    The function name is stored in ``_module_serving_func`` under the key
    ``'<func.__module__>.<name of the calling frame>'``. When the decorator is applied inside a class
    body, the calling frame is that class body, so the key ends with the class name.

    Args:
        func(Callable): The function to register.

    Returns:
        Callable: A wrapper that forwards every call to ``func`` and keeps its name and docstring.
    '''
    import functools

    # stack()[1] is the frame that invoked the decorator; index 3 is that frame's function name.
    mod = func.__module__ + '.' + inspect.stack()[1][3]
    _module_serving_func[mod] = func.__name__

    # Preserve __name__/__doc__ of the wrapped function instead of exposing '_wrapper'.
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


W
wuzewu 已提交
65
class Module(object):
    '''
    Entry point used to obtain a hub module by name or from a directory.

    Calling ``Module(name=...)`` or ``Module(directory=...)`` does not yield a ``Module`` instance:
    ``__new__`` returns the object built by ``init_with_name`` / ``init_with_directory``. For any class
    whose ``__name__`` is not ``'Module'`` (e.g. a subclass), a plain instance of that class is created.
    '''

    def __new__(cls,
                name: str = None,
                directory: str = None,
                version: str = None,
                source: str = None,
                update: bool = False,
                **kwargs):
        '''
        Args:
            name(str): Module name; takes precedence over ``directory`` when both are given.
            directory(str): Directory of the module, used when ``name`` is empty.
            version(str): Forwarded to ``init_with_name``.
            source(str): Forwarded to ``init_with_name``.
            update(bool): Forwarded to ``init_with_name``.
            **kwargs: Forwarded to the selected ``init_with_*`` method.

        Raises:
            ValueError: If called on ``Module`` itself with neither ``name`` nor ``directory``.
        '''
        if cls.__name__ != 'Module':
            return object.__new__(cls)

        # This branch comes from hub.Module(name='xxx') or hub.Module(directory='xxx')
        if name:
            return cls.init_with_name(name=name, version=version, source=source, update=update, **kwargs)
        if directory:
            return cls.init_with_directory(directory=directory, **kwargs)

        # Previously this fell through to an unbound local variable.
        raise ValueError('Either `name` or `directory` must be specified when creating a Module.')

    @classmethod
    def load(cls, directory: str) -> Generic:
        '''
        Load the module class stored in ``directory``.

        If ``module_desc.pb`` exists in the directory the work is delegated to ``ModuleV1.load``.
        Otherwise ``<basename>.module`` is imported from the parent directory and the first class that
        has a ``_hook_by_hub`` attribute and derives from ``RunModule`` is returned, after its
        ``directory``, ``source`` and ``branch`` class attributes have been set.

        Args:
            directory(str): Directory of the module; one trailing path separator is stripped.

        Returns:
            The located module class (or whatever ``ModuleV1.load`` returns).

        Raises:
            InvalidHubModule: If no suitable class is found in the imported python module.
        '''
        if directory.endswith(os.sep):
            directory = directory[:-1]

        # If the module description file existed, try to load as ModuleV1
        desc_file = os.path.join(directory, 'module_desc.pb')
        if os.path.exists(desc_file):
            return ModuleV1.load(directory)

        dirname, basename = os.path.split(directory)
        py_module = utils.load_py_module(dirname, '{}.module'.format(basename))

        for _, _item in inspect.getmembers(py_module, inspect.isclass):
            if hasattr(_item, '_hook_by_hub') and issubclass(_item, RunModule):
                user_module_cls = _item
                break
        else:
            raise InvalidHubModule(directory)

        user_module_cls.directory = directory

        # Optional metadata file; both attributes default to '' when it or its keys are missing.
        source_info_file = os.path.join(directory, '_source_info.yaml')
        if os.path.exists(source_info_file):
            info = parser.yaml_parser.parse(source_info_file)
            user_module_cls.source = info.get('source', '')
            user_module_cls.branch = info.get('branch', '')
        else:
            user_module_cls.source = ''
            user_module_cls.branch = ''
        return user_module_cls

    @classmethod
    def load_module_info(cls, directory: str) -> 'EasyDict':
        '''
        Read the module description of ``directory`` without importing the module code.

        For a ModuleV1 directory (``module_desc.pb`` present) this delegates to
        ``ModuleV1.load_module_info``. Otherwise ``module.py`` is parsed with ``ast`` and the keyword
        arguments (except ``meta``) of the first ``moduleinfo(...)`` class decorator are returned.

        Args:
            directory(str): Directory of the module.

        Returns:
            EasyDict: The literal keyword arguments passed to ``moduleinfo``.

        Raises:
            InvalidHubModule: If no class in ``module.py`` is decorated with ``moduleinfo(...)``.
        '''
        # If is ModuleV1
        desc_file = os.path.join(directory, 'module_desc.pb')
        if os.path.exists(desc_file):
            return ModuleV1.load_module_info(directory)

        # If is ModuleV2
        module_file = os.path.join(directory, 'module.py')
        # Read bytes so that ast.parse honours the file's own coding declaration (UTF-8 by default)
        # rather than the platform's default text encoding.
        with open(module_file, 'rb') as file:
            pycode = file.read()
        ast_module = ast.parse(pycode)

        for _body in ast_module.body:
            if not isinstance(_body, ast.ClassDef):
                continue

            for _decorator in _body.decorator_list:
                # Bare decorators such as `@foo` are not calls and carry no keywords.
                if not isinstance(_decorator, ast.Call):
                    continue

                # Accept both `@moduleinfo(...)` (ast.Name) and `@xxx.moduleinfo(...)` (ast.Attribute).
                _func = _decorator.func
                _func_name = getattr(_func, 'id', None) or getattr(_func, 'attr', None)
                if _func_name != 'moduleinfo':
                    continue

                info = {key.arg: ast.literal_eval(key.value) for key in _decorator.keywords if key.arg != 'meta'}
                return EasyDict(info)

        raise InvalidHubModule(directory)

    @staticmethod
    def _create_instance(user_module_cls, directory: str, init_kwargs: dict):
        '''Instantiate ``user_module_cls`` located at ``directory`` with the given keyword arguments.'''
        # The HubModule in the old version will use the _initialize method to initialize,
        # this function will be obsolete in a future version
        if hasattr(user_module_cls, '_initialize'):
            log.logger.warning(
                'The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object'
            )
            user_module = user_module_cls(directory=directory)
            user_module._initialize(**init_kwargs)
            return user_module

        if issubclass(user_module_cls, ModuleV1):
            return user_module_cls(directory=directory, **init_kwargs)

        user_module_cls.directory = directory
        return user_module_cls(**init_kwargs)

    @classmethod
    def init_with_name(cls,
                       name: str,
                       version: str = None,
                       source: str = None,
                       update: bool = False,
                       branch: str = None,
                       **kwargs):
        '''
        Build a module instance from its name.

        The local module manager is searched first; the module is installed when nothing is found or
        when the found module's version does not match ``version``.

        Args:
            name(str): Module name.
            version(str): Version requirement checked with ``version.match`` and passed to ``install``.
            source(str): Passed to the manager's ``search`` and ``install``.
            update(bool): Passed to the manager's ``install``.
            branch(str): Passed to the manager's ``search`` and ``install``.
            **kwargs: Passed to the module's initializer.

        Returns:
            The created module instance.
        '''
        from paddlehub.module.manager import LocalModuleManager
        manager = LocalModuleManager()
        user_module_cls = manager.search(name, source=source, branch=branch)
        if not user_module_cls or not user_module_cls.version.match(version):
            user_module_cls = manager.install(name=name, version=version, source=source, update=update, branch=branch)

        directory = manager._get_normalized_path(user_module_cls.name)
        return cls._create_instance(user_module_cls, directory, kwargs)

    @classmethod
    def init_with_directory(cls, directory: str, **kwargs):
        '''
        Build a module instance from the module stored in ``directory``.

        Args:
            directory(str): Directory of the module, loaded through ``load``.
            **kwargs: Passed to the module's initializer.

        Returns:
            The created module instance.
        '''
        user_module_cls = cls.load(directory)
        return cls._create_instance(user_module_cls, directory, kwargs)

    @classmethod
    def get_py_requirements(cls):
        '''
        Return the lines of ``requirements.txt`` found in ``cls.directory``.

        NOTE(review): ``directory`` is not defined on this class itself; presumably it is expected to
        have been assigned on ``cls`` beforehand - confirm against callers.

        Returns:
            list: The file content split on newlines, or an empty list if the file does not exist.
        '''
        req_file = os.path.join(cls.directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []

        with open(req_file, 'r') as file:
            return file.read().split('\n')

W
wuzewu 已提交
216

W
wuzewu 已提交
217
class RunModule(object):
    '''
    Base class that resolves the functions registered with the ``runnable`` / ``serving`` decorators
    for a module class and wraps method access so methods run through
    ``paddle_utils.run_in_static_mode`` unless the instance is a ``paddle.nn.Layer``.
    '''

    def __init__(self, *args, **kwargs):
        # Avoid module being initialized multiple times
        if '_is_initialize' in self.__dict__ and self._is_initialize:
            return

        super(RunModule, self).__init__()
        _run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
        self._is_initialize = True

    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
        '''
        Find the function name registered for ``current_cls`` or one of its ancestors.

        Classes are visited depth-first, following ``__bases__`` from left to right, and the first
        class whose ``'<module>.<name>'`` key is present in ``module_func_dict`` wins.

        Args:
            current_cls: The class to start the search from.
            module_func_dict(dict): Registry mapping ``'<module>.<class name>'`` to a function name.

        Returns:
            Optional[str]: The registered function name, or None if no class in the hierarchy has one.
        '''
        pending = [current_cls]
        while pending:
            candidate = pending.pop()
            mod = candidate.__module__ + '.' + candidate.__name__
            if mod in module_func_dict:
                return module_func_dict[mod]
            # Push bases reversed so the left-most base is examined next; every base is searched,
            # not only the first one.
            pending.extend(reversed(candidate.__bases__))
        return None

    # After the 2.0.0rc version, paddle uses the dynamic graph mode by default, which will cause the
    # execution of the static graph model to fail, so compatibility protection is required.
    def __getattribute__(self, attr):
        _attr = object.__getattribute__(self, attr)

        # If the acquired attribute is a built-in property of the object, skip it.
        if re.match('__.*__', attr):
            return _attr
        # If the module is a dygraph model, skip it.
        elif isinstance(self, paddle.nn.Layer):
            return _attr
        # If the acquired attribute is not a class method, skip it.
        elif not inspect.ismethod(_attr):
            return _attr

        return paddle_utils.run_in_static_mode(_attr)

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''
        Return the lines of the ``requirements.txt`` located next to the python file defining ``cls``.

        Returns:
            List[str]: The file content split on newlines, or an empty list if the file does not exist.
        '''
        py_module = sys.modules[cls.__module__]
        directory = os.path.dirname(py_module.__file__)
        req_file = os.path.join(directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []
        with open(req_file, 'r') as file:
            # Split into lines so both branches return a list, as the annotation promises.
            return file.read().split('\n')

    @property
    def is_runnable(self) -> bool:
        '''Whether a function registered with ``runnable`` was found for this module.'''
        return self._run_func is not None
W
wuzewu 已提交
275 276


W
wuzewu 已提交
277 278 279 280 281 282
def moduleinfo(name: str,
               version: str,
               author: str = None,
               author_email: str = None,
               summary: str = None,
               type: str = None,
               meta=None) -> Callable:
    '''
    Class decorator that attaches module metadata to a class and marks it with ``_hook_by_hub``.

    If the decorated class does not already derive from ``meta`` (``RunModule`` when ``meta`` is not
    given), a new class with the same name and attributes is created whose bases are the original
    bases plus ``meta``.

    Args:
        name(str): Stored as the ``name`` class attribute.
        version(str): Wrapped in ``utils.Version`` and stored as the ``version`` class attribute.
        author(str): Stored as the ``author`` class attribute.
        author_email(str): Stored as the ``author_email`` class attribute.
        summary(str): Stored as the ``summary`` class attribute.
        type(str): Stored as the ``type`` class attribute.
        meta: Base class the decorated class must derive from; defaults to ``RunModule``.

    Returns:
        Callable: The decorator to apply to the class.
    '''

    def _wrapper(cls: Generic) -> Generic:
        wrap_cls = cls
        _meta = RunModule if not meta else meta
        if not issubclass(cls, _meta):
            # Bases that _meta itself derives from are dropped, then _meta is appended.
            _bases = tuple(_b for _b in cls.__bases__ if not issubclass(_meta, _b)) + (_meta, )

            # The '__dict__' and '__weakref__' descriptors belong to the original class and do not
            # apply to instances of the rebuilt one, so they must not be copied over.
            _namespace = dict(cls.__dict__)
            _namespace.pop('__dict__', None)
            _namespace.pop('__weakref__', None)

            # `type` is shadowed by the parameter above, hence builtins.type.
            wrap_cls = builtins.type(cls.__name__, _bases, _namespace)

        wrap_cls.name = name
        wrap_cls.version = utils.Version(version)
        wrap_cls.author = author
        wrap_cls.author_email = author_email
        wrap_cls.summary = summary
        wrap_cls.type = type
        wrap_cls._hook_by_hub = True
        return wrap_cls

    return _wrapper