manager.py 11.3 KB
Newer Older
S
Steffy-zxf 已提交
1
#coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19 20 21

import os
import shutil
22

23
from functools import cmp_to_key
W
wuzewu 已提交
24
import tarfile
W
wuzewu 已提交
25 26 27
import sys
import importlib
import inspect
W
wuzewu 已提交
28

W
wuzewu 已提交
29
import paddlehub as hub
W
wuzewu 已提交
30 31 32
from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import MODULE_HOME
走神的阿圆's avatar
走神的阿圆 已提交
33
from paddlehub.common.cml_utils import paint_modules_info
K
kinghuin 已提交
34
from paddlehub.common.logger import logger
W
wuzewu 已提交
35 36
from paddlehub.common import tmp_dir
from paddlehub.module import module_desc_pb2
走神的阿圆's avatar
走神的阿圆 已提交
37 38
from paddlehub.version import hub_version as sys_hub_verion
from paddle import __version__ as sys_paddle_version
W
wuzewu 已提交
39 40


W
wuzewu 已提交
41
class LocalModuleManager(object):
    """Manage PaddleHub modules installed under a local directory.

    Every sub-directory of ``local_modules_dir`` that passes
    :meth:`check_module_valid` is treated as an installed module.  Scan
    results are cached in ``modules_dict`` as
    ``{module_name: (module_dir, version)}``.
    """

    def __init__(self, module_home=None):
        """Create a manager rooted at ``module_home`` (``MODULE_HOME`` if falsy).

        The directory is created when it does not exist yet.

        Raises:
            ValueError: if ``module_home`` points to an existing file.
        """
        self.local_modules_dir = module_home if module_home else MODULE_HOME
        # Cache filled by all_modules(): {name: (module_dir, version)}.
        self.modules_dict = {}
        if not os.path.exists(self.local_modules_dir):
            utils.mkdir(self.local_modules_dir)
        elif os.path.isfile(self.local_modules_dir):
            raise ValueError("Module home should be a folder, not a file")

    @staticmethod
    def _rename_with_underscore(module_dir):
        """Replace "-" with "_" in the last path component of ``module_dir``.

        Only the basename is touched, so a "-" in a parent directory (e.g. a
        home directory such as ``/home/my-user``) never relocates the module.
        The directory is left in place when the target name already exists,
        because ``shutil.move`` would otherwise nest it inside that directory.

        Returns:
            The (possibly renamed) module directory path.
        """
        parent, basename = os.path.split(os.path.normpath(module_dir))
        if "-" not in basename:
            return module_dir
        new_module_dir = os.path.join(parent, basename.replace("-", "_"))
        if os.path.exists(new_module_dir):
            return module_dir
        shutil.move(module_dir, new_module_dir)
        return new_module_dir

    @staticmethod
    def _find_module_class(module_path, module_file):
        """Import ``<module_path>/module.py`` and return its hub.Module subclass.

        Only classes whose defining source file path starts with
        ``module_file`` are considered, so ``hub.Module`` subclasses that
        module.py merely imports from elsewhere are ignored.

        Returns:
            The first matching class (in ``inspect.getmembers`` order), or
            None when module.py defines no ``hub.Module`` subclass.
        """
        dirname, basename = os.path.split(os.path.normpath(module_path))
        module_file = os.path.realpath(module_file)
        # Make "<basename>.module" importable, and always undo the change.
        sys.path.insert(0, dirname)
        try:
            _module = importlib.import_module("{}.module".format(basename))
        finally:
            if dirname in sys.path:
                sys.path.remove(dirname)
        for _, _cls in inspect.getmembers(_module, inspect.isclass):
            # Check the base class first: classes from built-in modules have
            # no __file__ and must simply be skipped.
            if not issubclass(_cls, hub.Module):
                continue
            source = getattr(
                sys.modules.get(_cls.__module__), "__file__", None)
            if source and os.path.realpath(source).startswith(module_file):
                return _cls
        return None

    def check_module_valid(self, module_path):
        """Check whether ``module_path`` holds a loadable module.

        Two layouts are recognised:

        * a ``module_desc.pb`` file, whose ``module_info`` attributes provide
          the name and version;
        * a ``module.py`` file defining a ``hub.Module`` subclass, whose
          ``_name`` / ``_version`` attributes provide the name and version.

        Args:
            module_path: directory of a candidate module.

        Returns:
            ``(True, {'version': ..., 'name': ...})`` when the module is
            valid, otherwise ``(False, None)``.  Errors raised while loading
            are logged and reported as "invalid" rather than propagated.
        """
        try:
            desc_pb_path = os.path.join(module_path, 'module_desc.pb')
            if os.path.isfile(desc_pb_path):
                desc = module_desc_pb2.ModuleDesc()
                with open(desc_pb_path, "rb") as fp:
                    desc.ParseFromString(fp.read())
                module_info = desc.attr.map.data["module_info"].map.data
                return True, {
                    'version': module_info["version"].s,
                    'name': module_info["name"].s
                }

            module_file = os.path.join(module_path, 'module.py')
            if os.path.exists(module_file):
                module_cls = self._find_module_class(module_path, module_file)
                if module_cls is not None:
                    return True, {
                        'version': module_cls._version,
                        'name': module_cls._name
                    }
                logger.warning(
                    "No hub.Module subclass is defined in %s" % module_file)
            else:
                logger.warning(
                    "%s does not exist, the module will be reinstalled" %
                    desc_pb_path)
        except Exception as e:
            # Best effort: a directory that cannot be loaded is reported as
            # invalid instead of aborting the scan of all modules.
            logger.warning(
                "Failed to load module from %s: %s" % (module_path, e))
        return False, None

    def all_modules(self, update=False):
        """Return ``{name: (module_dir, version)}`` for all installed modules.

        Args:
            update: rescan ``local_modules_dir`` even when a cached, non-empty
                result exists.

        Sub-directory names containing "-" are renamed to use "_" first.
        Directory names are used as package names when importing
        ``module.py``, and "-" is not valid in a Python identifier.
        """
        if not update and self.modules_dict:
            return self.modules_dict
        self.modules_dict = {}
        for sub_dir_name in os.listdir(self.local_modules_dir):
            sub_dir_path = os.path.join(self.local_modules_dir, sub_dir_name)
            if not os.path.isdir(sub_dir_path):
                continue
            sub_dir_path = self._rename_with_underscore(sub_dir_path)
            valid, info = self.check_module_valid(sub_dir_path)
            if valid:
                self.modules_dict[info['name']] = (sub_dir_path,
                                                   info['version'])
        return self.modules_dict

    def search_module(self, module_name, module_version=None, update=False):
        """Return ``(module_dir, version)`` of an installed module, or None.

        Note: ``module_version`` is accepted for API compatibility but is not
        used for matching; whatever version is installed is returned.
        """
        return self.all_modules(update=update).get(module_name, None)

    def _already_installed(self, module_name, module_version=None):
        """Build install_module's result tuple for an already-installed module."""
        module_dir = self.modules_dict[module_name][0]
        module_tag = module_name if not module_version else '%s-%s' % (
            module_name, module_version)
        tips = "Module %s already installed in %s" % (module_tag, module_dir)
        return True, tips, self.modules_dict[module_name]

    @staticmethod
    def _module_not_found_tips(module_name, module_version):
        """Explain why the hub server returned no usable download for a module."""
        server = hub.HubServer()
        if server._server_check() is False:
            return "Request Hub-Server unsuccessfully, please check your network."
        module_versions_info = server.search_module_info(module_name)
        if module_versions_info is None:
            return "Can't find module %s, please check your spelling." \
                   % (module_name)
        if module_version is not None and module_version not in [
                item[1] for item in module_versions_info
        ]:
            tips = "Can't find module %s with version %s, all versions are listed below." \
                   % (module_name, module_version)
        else:
            tips = "The version of PaddlePaddle(%s) or PaddleHub(%s) can not match module, please upgrade your PaddlePaddle or PaddleHub according to the form below." \
                   % (sys_paddle_version, sys_hub_verion)
        return tips + paint_modules_info(module_versions_info)

    def _extract_package(self, module_package, extract_dir):
        """Extract a ``.tar.gz`` module package into ``extract_dir``.

        Members whose names resolve outside ``extract_dir`` (absolute paths,
        ".." components) are rejected before anything is extracted.

        Returns:
            ``(ok, tips, module_dir)``.  ``module_dir`` is the package's
            top-level directory, with "-" in its name replaced by "_".
        """
        extract_root = os.path.realpath(extract_dir)
        with tarfile.open(module_package, "r:gz") as tar:
            members = tar.getmembers()
            if not members:
                return False, "Module package %s is empty" % module_package, None
            for member in members:
                target = os.path.realpath(
                    os.path.join(extract_root, member.name))
                if target != extract_root and not target.startswith(
                        extract_root + os.sep):
                    tips = "Unsafe path %s in module package %s" % (
                        member.name, module_package)
                    return False, tips, None
            tar.extractall(extract_dir, members=members)
        # The first member names the package's top-level directory.
        top_dir = os.path.normpath(members[0].name).split(os.sep)[0]
        module_dir = os.path.join(extract_dir, top_dir)
        return True, None, self._rename_with_underscore(module_dir)

    def install_module(self,
                       module_name=None,
                       module_dir=None,
                       module_package=None,
                       module_version=None,
                       upgrade=False,
                       extra=None):
        """Install a module into ``local_modules_dir``.

        The module comes from one of three sources:

        * ``module_name`` (optionally with ``module_version``): downloaded
          from the PaddleHub server;
        * ``module_dir``: a module directory supplied by the user (copied);
        * ``module_package``: a local ``.tar.gz`` module package.

        Args:
            module_name: name of the module to fetch from the server.
            module_dir: path of a local module directory to install.
            module_package: path of a local ``.tar.gz`` module package.
            module_version: version requested from the server.
            upgrade: accepted for compatibility; not used here.
            extra: forwarded to ``HubServer.get_module_url``.

        Returns:
            ``(success, tips, info)``.  On success, or when the module is
            already installed, ``info`` is ``(module_dir, version)``; the
            version may be None when it is not known.  On failure ``info`` is
            None or falsy.
        """
        md5_value = installed_module_version = None
        from_user_dir = bool(module_dir)
        with tmp_dir() as _dir:
            if module_name:
                self.all_modules(update=True)
                module_info = self.modules_dict.get(module_name, None)
                if module_info and (not module_version
                                    or module_version == module_info[1]):
                    return self._already_installed(module_name,
                                                   module_version)

                search_result = hub.HubServer().get_module_url(
                    module_name, version=module_version, extra=extra)
                name = search_result.get('name', None)
                url = search_result.get('url', None)
                md5_value = search_result.get('md5', None)
                installed_module_version = search_result.get('version', None)
                if not url or (module_version is not None
                               and installed_module_version != module_version
                               ) or (name != module_name):
                    tips = self._module_not_found_tips(module_name,
                                                       module_version)
                    return False, tips, None

                result, tips, module_zip_file = default_downloader.download_file(
                    url=url,
                    save_path=_dir,
                    save_name=module_name,
                    replace=True,
                    print_progress=True)
                if not result:
                    return False, tips, None
                result, tips, module_dir = default_downloader.uncompress(
                    file=module_zip_file,
                    dirname=os.path.join(_dir, "tmp_module"),
                    delete_file=True,
                    print_progress=True)
                if not result:
                    return False, tips, None

            if module_package:
                result, tips, module_dir = self._extract_package(
                    module_package, _dir)
                if not result:
                    return False, tips, None
                module_name = hub.Module(directory=module_dir).name

            if from_user_dir:
                # Load the module once; only its name and version are needed.
                user_module = hub.Module(directory=module_dir)
                module_name = user_module.name
                module_version = user_module.version
                self.all_modules(update=False)
                module_info = self.modules_dict.get(module_name, None)
                if module_info and module_version == module_info[1]:
                    return self._already_installed(module_name,
                                                   module_version)

            if module_dir:
                if md5_value:
                    with open(os.path.join(module_dir, "md5.txt"), "w") as fp:
                        fp.write(md5_value)

                save_path = os.path.join(self.local_modules_dir,
                                         module_name.replace("-", "_"))
                # Compare resolved paths: a differently spelled path to the
                # same directory must not trigger rmtree() on the source.
                if os.path.realpath(save_path) != os.path.realpath(module_dir):
                    if os.path.exists(save_path):
                        shutil.rmtree(save_path)
                    if from_user_dir:
                        shutil.copytree(module_dir, save_path)
                    else:
                        shutil.move(module_dir, save_path)
                module_dir = save_path
                # The cached module list no longer matches the disk; force a
                # rescan on the next all_modules() call.
                self.modules_dict = {}
                tips = "Successfully installed %s" % module_name
                if installed_module_version:
                    tips += "-%s" % installed_module_version
                return True, tips, (module_dir, installed_module_version)
            tips = "Download %s-%s failed" % (module_name, module_version)
            return False, tips, module_dir

    def uninstall_module(self, module_name, module_version=None):
        """Remove an installed module's directory.

        Returns:
            ``(True, tips)``.  Asking to uninstall a module (or a version)
            that is not installed is reported in ``tips``, not as an error.
        """
        self.all_modules(update=True)
        if module_name not in self.modules_dict:
            tips = "%s is not installed" % module_name
            return True, tips
        module_dir, installed_version = self.modules_dict[module_name]
        if module_version and module_version != installed_version:
            tips = "%s-%s is not installed" % (module_name, module_version)
            return True, tips
        tips = "Successfully uninstalled %s" % module_name
        if module_version:
            tips += '-%s' % module_version
        shutil.rmtree(module_dir)
        # Keep the cache consistent with what is on disk.
        self.modules_dict.pop(module_name, None)
        return True, tips
W
wuzewu 已提交
244 245


W
wuzewu 已提交
246
# Shared module manager rooted at MODULE_HOME (used when no module_home is given).
default_module_manager = LocalModuleManager(module_home=None)