module.py 7.6 KB
Newer Older
S
Steffy-zxf 已提交
1
# coding:utf-8
W
wuzewu 已提交
2
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
W
wuzewu 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wuzewu 已提交
16 17
import functools
import importlib
import inspect
W
wuzewu 已提交
18 19
import os
import sys
W
wuzewu 已提交
20
from typing import Callable, Generic, List, Optional
W
wuzewu 已提交
21

W
wuzewu 已提交
22
from paddlehub.utils import log, utils
W
wuzewu 已提交
23
from paddlehub.compat.module.module_v1 import ModuleV1
W
wuzewu 已提交
24

W
wuzewu 已提交
25

W
wuzewu 已提交
26
class InvalidHubModule(Exception):
    '''
    Raised when a directory cannot be loaded as a HubModule.

    Args:
        directory(str): the directory that failed to load; kept on the
            exception as `directory`.
    '''

    def __init__(self, directory: str):
        super(InvalidHubModule, self).__init__(directory)
        self.directory = directory

    def __str__(self) -> str:
        template = '{} is not a valid HubModule'
        return template.format(self.directory)


# Registries filled by the `@runnable` / `@serving` decorators below. Each key
# is '<module of the decorated function>.<name of the scope that applied the
# decorator>' (for a method decorated inside a class body, that scope is the
# class), and each value is the decorated function's name. RunModule looks
# classes up in these dicts by '<cls.__module__>.<cls.__name__>'.
_module_serving_func = {}
_module_runnable_func = {}


def _register(registry: dict, func: Callable, scope_name: str) -> Callable:
    '''
    Record `func` in `registry` under its owning scope and return a wrapper
    that forwards every call to `func`.
    '''
    registry[func.__module__ + '.' + scope_name] = func.__name__

    # functools.wraps copies __module__/__name__/__doc__ onto the wrapper, so
    # stacking @serving on top of @runnable (or vice versa) still registers the
    # user's method instead of this module's `_wrapper`.
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


def runnable(func: Callable) -> Callable:
    '''
    Mark `func` as the method RunModule exposes through `_run_func` /
    `is_runnable`. Intended to be applied inside a class body.
    '''
    # Name of the frame that applied the decorator (the class body when used
    # on a method). Same value as inspect.stack()[1][3], without building the
    # whole stack and reading source context for every frame.
    caller = inspect.currentframe().f_back
    return _register(_module_runnable_func, func, caller.f_code.co_name)


def serving(func: Callable) -> Callable:
    '''
    Mark `func` as the serving method; RunModule records its name in
    `_serving_func_name`. Intended to be applied inside a class body.
    '''
    # See `runnable` for why the caller frame name is used as the scope.
    caller = inspect.currentframe().f_back
    return _register(_module_serving_func, func, caller.f_code.co_name)


W
wuzewu 已提交
58
class Module(object):
    '''
    Factory entry point for HubModules.

    Calling `Module(name=...)` or `Module(directory=...)` directly does not
    produce a `Module` instance: `__new__` locates the user-defined module
    class and returns an instance of that class instead. Subclasses of
    `Module` are constructed normally.
    '''

    def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs):
        if cls.__name__ != 'Module':
            # A subclass is being constructed; create it like any other object.
            return object.__new__(cls)

        # This branch comes from hub.Module(name='xxx') or hub.Module(directory='xxx');
        # `name` takes precedence when both are given.
        if name:
            return cls.init_with_name(name=name, version=version, **kwargs)
        if directory:
            return cls.init_with_directory(directory=directory, **kwargs)
        # Without a name or a directory there is nothing to load.
        raise ValueError('Either `name` or `directory` must be specified to load a HubModule')

    @classmethod
    def load(cls, directory: str) -> Generic:
        '''
        Load the module class stored in `directory` without instantiating it.

        Args:
            directory(str): path of the module directory; a single trailing
                path separator is ignored.

        Returns:
            `ModuleV1.load(<directory>/module_desc.pb)` when that file exists.
            Otherwise, the first class (in name order) of the Python module
            `<basename>.module` loaded from the parent directory that is a
            `RunModule` subclass decorated with `moduleinfo`. The class gets
            `directory` attached as a class attribute.

        Raises:
            InvalidHubModule: if no such class is found.
        '''
        if directory.endswith(os.sep):
            directory = directory[:-1]

        # If module description file existed, try to load as ModuleV1
        desc_file = os.path.join(directory, 'module_desc.pb')
        if os.path.exists(desc_file):
            return ModuleV1.load(desc_file)

        dirname, basename = os.path.split(directory)
        py_module = utils.load_py_module(dirname, '{}.module'.format(basename))

        for _, _cls in inspect.getmembers(py_module, inspect.isclass):
            # `_hook_by_hub` is set by the `moduleinfo` decorator.
            if hasattr(_cls, '_hook_by_hub') and issubclass(_cls, RunModule):
                user_module_cls = _cls
                break
        else:
            raise InvalidHubModule(directory)

        user_module_cls.directory = directory
        return user_module_cls

    @staticmethod
    def _create_module_instance(module_cls, module_dir: str, **kwargs):
        '''
        Instantiate `module_cls` rooted at `module_dir`.

        Classes that define `_initialize` are treated as legacy modules: they
        are constructed with only `directory`, a deprecation warning is
        logged, and `kwargs` are passed to `_initialize` instead.
        '''
        # The HubModule in the old version will use the _initialize method to initialize,
        # this function will be obsolete in a future version
        if hasattr(module_cls, '_initialize'):
            log.logger.warning(
                'The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object'
            )
            user_module = module_cls(directory=module_dir)
            user_module._initialize(**kwargs)
            return user_module
        return module_cls(directory=module_dir, **kwargs)

    @classmethod
    def init_with_name(cls, name: str, version: str = None, **kwargs):
        '''
        Instantiate the locally installed module `name`.

        The module is (re)installed through `LocalModuleManager` when it is not
        found locally or when the local copy's version does not match `version`.

        Args:
            name(str): module name.
            version(str): version requirement passed to `Version.match` and
                to `LocalModuleManager.install`.
            **kwargs: forwarded to the module's initializer.
        '''
        from paddlehub.module.manager import LocalModuleManager
        manager = LocalModuleManager()
        user_module_cls = manager.search(name)
        if not user_module_cls or not user_module_cls.version.match(version):
            user_module_cls = manager.install(name, version)

        directory = manager._get_normalized_path(name)
        return cls._create_module_instance(user_module_cls, directory, **kwargs)

    @classmethod
    def init_with_directory(cls, directory: str, **kwargs):
        '''
        Instantiate the module stored in `directory` (located via `load`).

        Args:
            directory(str): path of the module directory.
            **kwargs: forwarded to the module's initializer.
        '''
        user_module_cls = cls.load(directory)
        return cls._create_module_instance(user_module_cls, directory, **kwargs)

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''
        Read `requirements.txt` from `cls.directory`.

        Returns:
            One stripped entry per non-blank line, or [] when the file does not exist.
        '''
        req_file = os.path.join(cls.directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []

        with open(req_file, 'r') as file:
            # Blank lines (including the usual trailing newline) are not requirements.
            lines = (line.strip() for line in file)
            return [line for line in lines if line]

W
wuzewu 已提交
152

W
wuzewu 已提交
153
class RunModule(object):
    '''
    Base class of user-defined HubModules.

    On first construction it looks up which method of the class (or of its
    ancestors) was registered with `@runnable` and `@serving`, and stores the
    bound runnable method as `_run_func` and the serving method name as
    `_serving_func_name`.
    '''

    def __init__(self, *args, **kwargs):
        # Avoid module being initialized multiple times
        if '_is_initialize' in self.__dict__ and self._is_initialize:
            return

        super(RunModule, self).__init__()
        _run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
        self._is_initialize = True

    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
        '''
        Return the function name registered for `current_cls` or its nearest ancestor.

        Args:
            current_cls: class to start the search from.
            module_func_dict(dict): registry keyed by '<module>.<class name>'.

        Returns:
            The registered function name, or None if no class in the MRO is registered.
        '''
        # Walk the full MRO so that registrations on any base class are found,
        # not only those along the first-base chain.
        for klass in current_cls.__mro__:
            key = klass.__module__ + '.' + klass.__name__
            if key in module_func_dict:
                return module_func_dict[key]
        return None

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''
        Read `requirements.txt` located next to the file defining this class.

        Returns:
            One stripped entry per non-blank line, or [] when the file does not exist.
        '''
        py_module = sys.modules[cls.__module__]
        directory = os.path.dirname(py_module.__file__)
        req_file = os.path.join(directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []
        with open(req_file, 'r') as file:
            # Return a list (as annotated), not the raw file text.
            lines = (line.strip() for line in file)
            return [line for line in lines if line]

    @property
    def is_runnable(self) -> bool:
        '''Whether a `@runnable` method was found for this module.'''
        return self._run_func is not None
W
wuzewu 已提交
194 195


W
wuzewu 已提交
196
# `moduleinfo` takes a parameter named `type`, which shadows the builtin inside
# it; this module-level alias keeps the builtin reachable there.
sys_type = type


def moduleinfo(name: str,
               version: str,
               author: str = None,
               author_email: str = None,
               summary: str = None,
               type: str = None,
               meta=None) -> Callable:
    '''
    Class decorator that turns a user class into a HubModule.

    Args:
        name(str): module name, stored as `cls.name`.
        version(str): module version, stored as `cls.version` (wrapped in `utils.Version`).
        author(str): stored as `cls.author`.
        author_email(str): stored as `cls.author_email`.
        summary(str): stored as `cls.summary`.
        type(str): stored as `cls.type`.
        meta: base class the decorated class must derive from; `RunModule` when not given.

    Returns:
        A decorator that returns the class itself when it already derives from
        `meta`, or otherwise a new class with the same name and namespace whose
        bases end with `meta`. Either way the result carries the metadata above
        and `_hook_by_hub = True`.
    '''

    def _wrapper(cls: Generic) -> Generic:
        _meta = meta if meta else RunModule
        wrap_cls = cls
        if not issubclass(cls, _meta):
            # Drop bases that `_meta` already provides (e.g. `object`) so the
            # new class has a consistent MRO, then append `_meta` last.
            _bases = tuple(_b for _b in cls.__bases__ if not issubclass(_meta, _b)) + (_meta, )
            namespace = dict(cls.__dict__)
            # These descriptors belong to the original class. Copying them would
            # make `instance.__dict__` on the rebuilt class raise TypeError;
            # the new class inherits working ones from its bases.
            namespace.pop('__dict__', None)
            namespace.pop('__weakref__', None)
            wrap_cls = sys_type(cls.__name__, _bases, namespace)

        wrap_cls.name = name
        wrap_cls.version = utils.Version(version)
        wrap_cls.author = author
        wrap_cls.author_email = author_email
        wrap_cls.summary = summary
        wrap_cls.type = type
        # Marker used by `Module.load` to find the module class.
        wrap_cls._hook_by_hub = True
        return wrap_cls

    return _wrapper