module.py 5.7 KB
Newer Older
S
Steffy-zxf 已提交
1
# coding:utf-8
W
wuzewu 已提交
2
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
W
wuzewu 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wuzewu 已提交
16 17
import inspect
import importlib
W
wuzewu 已提交
18 19
import os
import sys
W
wuzewu 已提交
20 21 22

import paddle.fluid as fluid

W
wuzewu 已提交
23 24
from paddlehub.utils import utils

W
wuzewu 已提交
25

W
wuzewu 已提交
26 27 28 29 30 31 32 33 34
class InvalidHubModule(Exception):
    """Error raised for a directory that does not hold a valid HubModule.

    The offending path is stored on ``directory`` and is embedded in the
    message returned by ``str()``.
    """

    def __init__(self, directory):
        # Keep the path around so ``__str__`` can report it.
        self.directory = directory

    def __str__(self):
        template = '{} is not a valid HubModule'
        return template.format(self.directory)


# Registry filled by the ``serving`` decorator: maps
# '<function module>.<name of the frame that applied the decorator>' to the
# name of the decorated function.
_module_serving_func = {}
W
wuzewu 已提交
35
# Registry filled by the ``runnable`` decorator: maps
# '<function module>.<name of the frame that applied the decorator>' to the
# name of the decorated function.
_module_runnable_func = {}
W
wuzewu 已提交
36 37


W
wuzewu 已提交
38
def runnable(func):
    """Register ``func`` as the runnable entry point of the scope applying it.

    The function name is stored in ``_module_runnable_func`` under the key
    '<func.__module__>.<name of the calling frame>'. When the decorator is
    applied inside a class body, the calling frame is named after that class,
    which is the same '<module>.<class name>' key that
    ``RunModule._get_func_name`` builds when it looks the entry point up.

    Args:
        func: the function being decorated.

    Returns:
        A wrapper that forwards every call to ``func`` unchanged and keeps its
        metadata (``__name__``, ``__doc__``, ...).
    """
    # Local import so this block does not depend on the file-level imports.
    import functools

    # inspect.stack()[1] is the frame that invoked the decorator; index 3 of a
    # frame record is that frame's function name.
    owner = inspect.stack()[1][3]
    _module_runnable_func[func.__module__ + '.' + owner] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


走神的阿圆's avatar
走神的阿圆 已提交
48
def serving(func):
    """Register ``func`` as the serving entry point of the scope applying it.

    The function name is stored in ``_module_serving_func`` under the key
    '<func.__module__>.<name of the calling frame>'. When the decorator is
    applied inside a class body, the calling frame is named after that class,
    which is the same '<module>.<class name>' key that
    ``RunModule._get_func_name`` builds when it looks the entry point up.

    Args:
        func: the function being decorated.

    Returns:
        A wrapper that forwards every call to ``func`` unchanged and keeps its
        metadata (``__name__``, ``__doc__``, ...).
    """
    # Local import so this block does not depend on the file-level imports.
    import functools

    # inspect.stack()[1] is the frame that invoked the decorator; index 3 of a
    # frame record is that frame's function name.
    owner = inspect.stack()[1][3]
    _module_serving_func[func.__module__ + '.' + owner] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


W
wuzewu 已提交
58 59 60
class Module(object):
    """Factory that resolves and instantiates a user-defined hub module.

    ``Module(...)`` does not return a ``Module`` instance: ``__new__`` either
    looks the module up by ``name`` through the local module manager or loads
    it from ``directory``, and returns an instance of the class found there.
    """

    def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs):
        """Build the module identified by ``name`` or located in ``directory``.

        Args:
            name: module name; takes precedence over ``directory`` when truthy.
            directory: path of a module directory, used when ``name`` is not
                given.
            version: version forwarded to ``init_with_name``.
            **kwargs: forwarded to the constructor of the resolved class.

        Returns:
            An instance of the resolved user module class.

        Raises:
            RuntimeError: if called through a class not named ``Module``, or
                if neither ``name`` nor ``directory`` is given.
        """
        if cls.__name__ != 'Module':
            raise RuntimeError('Module cannot be instantiated through subclass {}'.format(cls.__name__))

        if name:
            module = cls.init_with_name(name=name, version=version, **kwargs)
        elif directory:
            module = cls.init_with_directory(directory=directory, **kwargs)
        else:
            # Previously this fell through and failed with UnboundLocalError.
            raise RuntimeError('Either `name` or `directory` is required to create a Module')

        # ``load`` stores the directory on the class it returns; do not hide
        # such a value behind ``None`` when the caller passed no directory.
        if directory is not None or getattr(module, 'directory', None) is None:
            module.directory = directory
        return module

    @classmethod
    def load(cls, directory: str):
        """Import ``<directory>/module.py`` and return the hub class it defines.

        The parent of ``directory`` is put at the front of ``sys.path`` only for
        the duration of the import. The first class (in ``inspect.getmembers``
        order) that carries ``_hook_by_hub`` and derives from ``RunModule`` is
        returned, with ``directory`` stored on it as a class attribute.

        Args:
            directory: path of the module directory; one trailing separator is
                ignored.

        Returns:
            The user module class (not an instance).

        Raises:
            InvalidHubModule: if no suitable class is found in the module.
        """
        if directory.endswith(os.sep):
            directory = directory[:-1]

        parent_dir, basename = os.path.split(directory)

        sys.path.insert(0, parent_dir)
        try:
            py_module = importlib.import_module('{}.module'.format(basename))
        finally:
            # Always undo the sys.path change, including when the import fails.
            # Remove by value: the imported code may itself have touched
            # sys.path, so index 0 is not guaranteed to be our entry.
            try:
                sys.path.remove(parent_dir)
            except ValueError:
                pass

        for _name, candidate in inspect.getmembers(py_module, inspect.isclass):
            if hasattr(candidate, '_hook_by_hub') and issubclass(candidate, RunModule):
                candidate.directory = directory
                return candidate

        raise InvalidHubModule(directory)

    @classmethod
    def init_with_name(cls, name: str, version: str = None, **kwargs):
        """Instantiate the module called ``name`` via ``LocalModuleManager``.

        The manager is asked to install the module when the search result has
        no ``'module'`` entry or when that class's ``version.match(version)``
        is falsy.

        Args:
            name: module name passed to the manager.
            version: requested version.
            **kwargs: forwarded to the constructor of the resolved class.

        Returns:
            An instance of the resolved user module class.
        """
        # Imported inside the method, as in the original code
        # (presumably to avoid a circular import -- TODO confirm).
        from paddlehub.module.manager import LocalModuleManager

        manager = LocalModuleManager()
        search_result = manager.search(name)
        user_module_cls = search_result.get('module', None)
        if not user_module_cls or not user_module_cls.version.match(version):
            user_module_cls = manager.install(name, version)

        return user_module_cls(**kwargs)

    @classmethod
    def init_with_directory(cls, directory: str, **kwargs):
        """Instantiate the module class loaded from ``directory``.

        Args:
            directory: path handed to ``load``.
            **kwargs: forwarded to the constructor of the loaded class.

        Returns:
            An instance of the loaded user module class.
        """
        user_module_cls = cls.load(directory)
        return user_module_cls(**kwargs)
W
wuzewu 已提交
110

W
wuzewu 已提交
111

W
wuzewu 已提交
112 113 114 115
class RunModule(object):
    """Base class that wires up a module's runnable and serving entry points.

    On construction the names registered by the ``runnable`` / ``serving``
    decorators are looked up for the instance's class (and its bases) and
    stored as ``_run_func`` (bound method or ``None``) and
    ``_serving_func_name`` (name or ``None``).
    """

    def __init__(self, *args, **kwargs):
        # Avoid module being initialized multiple times
        if self.__dict__.get('_is_initialize'):
            return

        super(RunModule, self).__init__()
        _run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
        self._is_initialize = True

    def _get_func_name(self, current_cls, module_func_dict):
        """Find the function name registered for ``current_cls`` or a base.

        Args:
            current_cls: class to start the lookup from.
            module_func_dict: mapping of '<module>.<class name>' to a function
                name.

        Returns:
            The registered name for ``current_cls`` if present; otherwise the
            first name found by searching its bases depth-first, left to
            right; ``None`` when nothing is registered.
        """
        key = current_cls.__module__ + '.' + current_cls.__name__
        if key in module_func_dict:
            return module_func_dict[key]

        # Search every base, not only the first one, so that a registration
        # on a later base of a multiple-inheritance class is still found.
        for base_class in current_cls.__bases__:
            found = self._get_func_name(base_class, module_func_dict)
            if found is not None:
                return found
        return None

    @classmethod
    def get_py_requirements(cls):
        """Read the ``requirements.txt`` next to the file defining ``cls``.

        Returns:
            An empty list when the file does not exist, otherwise the raw file
            content as a single string.
        """
        # NOTE(review): the two branches return different types (list vs str).
        # Kept as-is because callers are not visible here -- confirm which
        # shape they expect before unifying.
        py_module = sys.modules[cls.__module__]
        directory = os.path.dirname(py_module.__file__)
        req_file = os.path.join(directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []
        with open(req_file, 'r') as file:
            return file.read()

    @property
    def is_runnable(self) -> bool:
        """Whether a runnable entry point was found for this instance."""
        return self._run_func is not None
W
wuzewu 已提交
148 149


W
wuzewu 已提交
150
# Alias of the builtin ``type``: ``moduleinfo`` takes a parameter named ``type``
# that shadows the builtin, so it creates classes through this alias instead.
sys_type = type
W
wuzewu 已提交
151

W
wuzewu 已提交
152

W
wuzewu 已提交
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
def moduleinfo(name: str,
               version: str,
               author: str = None,
               author_email: str = None,
               summary: str = None,
               type: str = None,
               meta=None):
    """Class decorator that attaches hub metadata to a module class.

    If the decorated class does not already derive from ``meta`` (``RunModule``
    when ``meta`` is falsy), a new class with the same name and body is created
    whose bases are the original bases (minus any that ``meta`` itself derives
    from) followed by ``meta``.

    Args:
        name: stored as the class attribute ``name``.
        version: wrapped in ``utils.Version`` and stored as ``version``.
        author: stored as ``author``.
        author_email: stored as ``author_email``.
        summary: stored as ``summary``.
        type: stored as ``type``.
        meta: base class the decorated class must derive from.

    Returns:
        A decorator returning the (possibly rebuilt) class, tagged with
        ``_hook_by_hub = True``.
    """

    def _wrapper(cls):
        wrap_cls = cls
        _meta = meta if meta else RunModule
        if not issubclass(cls, _meta):
            # Drop bases that ``_meta`` already derives from (e.g. ``object``)
            # and put ``_meta`` last.
            _bases = [_b for _b in cls.__bases__ if not issubclass(_meta, _b)]
            _bases.append(_meta)

            namespace = dict(cls.__dict__)
            # The per-class ``__dict__`` / ``__weakref__`` descriptors (and slot
            # descriptors) belong to the original class; copying them makes
            # attribute access on instances of the new class raise TypeError.
            namespace.pop('__dict__', None)
            namespace.pop('__weakref__', None)
            slots = namespace.get('__slots__')
            if slots is not None:
                if isinstance(slots, str):
                    slots = [slots]
                for slot_name in slots:
                    namespace.pop(slot_name, None)

            wrap_cls = sys_type(cls.__name__, tuple(_bases), namespace)

        wrap_cls.name = name
        wrap_cls.version = utils.Version(version)
        wrap_cls.author = author
        wrap_cls.author_email = author_email
        wrap_cls.summary = summary
        wrap_cls.type = type
        # Marker checked by ``Module.load`` to recognise hub module classes.
        wrap_cls._hook_by_hub = True
        return wrap_cls

    return _wrapper