manager.py 10.3 KB
Newer Older
S
Steffy-zxf 已提交
1
#coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19 20 21

import os
import shutil
22

23
from functools import cmp_to_key
W
wuzewu 已提交
24
import tarfile
W
wuzewu 已提交
25

W
wuzewu 已提交
26
import paddlehub as hub
W
wuzewu 已提交
27 28 29
from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import MODULE_HOME
30
from paddlehub.common.cml_utils import TablePrinter
K
kinghuin 已提交
31
from paddlehub.common.logger import logger
W
wuzewu 已提交
32 33
from paddlehub.common import tmp_dir
from paddlehub.module import module_desc_pb2
W
wuzewu 已提交
34 35


W
wuzewu 已提交
36
class LocalModuleManager(object):
    """Manage PaddleHub modules installed under a local directory.

    Every immediate sub-directory of ``local_modules_dir`` that contains a
    parseable ``module_desc.pb`` is treated as an installed module; the
    sub-directory name is used as the module name.
    """

    def __init__(self, module_home=None):
        """Create a manager rooted at ``module_home``.

        Args:
            module_home: directory holding installed modules. Falls back to
                ``MODULE_HOME`` when falsy. Created if it does not exist.

        Raises:
            ValueError: if ``module_home`` exists but is a regular file.
        """
        self.local_modules_dir = module_home if module_home else MODULE_HOME
        # Cache of module name -> (module directory, version), filled lazily
        # by all_modules().
        self.modules_dict = {}
        if not os.path.exists(self.local_modules_dir):
            utils.mkdir(self.local_modules_dir)
        elif os.path.isfile(self.local_modules_dir):
            raise ValueError("Module home should be a folder, not a file")

    def check_module_valid(self, module_path):
        """Check whether ``module_path`` holds a readable module description.

        Args:
            module_path: candidate module directory.

        Returns:
            ``(True, {'version': ...})`` when ``module_desc.pb`` exists and
            parses, otherwise ``(False, None)``.
        """
        desc_pb_path = os.path.join(module_path, 'module_desc.pb')
        if not os.path.isfile(desc_pb_path):
            logger.warning(
                "%s does not exist, the module will be reinstalled" %
                desc_pb_path)
            return False, None
        try:
            desc = module_desc_pb2.ModuleDesc()
            with open(desc_pb_path, "rb") as fp:
                desc.ParseFromString(fp.read())
            version = desc.attr.map.data["module_info"].map.data[
                "version"].s
        except Exception as err:
            # Best effort: a corrupt description marks the module invalid
            # (so it will be reinstalled) rather than aborting the whole scan,
            # but the reason is logged instead of silently discarded.
            logger.warning("Failed to parse %s: %s" % (desc_pb_path, err))
            return False, None
        return True, {'version': version}

    def all_modules(self, update=False):
        """Return the mapping of installed module name -> (dir, version).

        Args:
            update: rescan ``local_modules_dir`` even if a cached result
                exists. An empty cache is always rescanned.

        Returns:
            ``self.modules_dict``.
        """
        if self.modules_dict and not update:
            return self.modules_dict
        self.modules_dict = {}
        for sub_dir_name in os.listdir(self.local_modules_dir):
            sub_dir_path = os.path.join(self.local_modules_dir, sub_dir_name)
            if not os.path.isdir(sub_dir_path):
                continue
            valid, info = self.check_module_valid(sub_dir_path)
            if valid:
                self.modules_dict[sub_dir_name] = (sub_dir_path,
                                                   info['version'])
        return self.modules_dict

    def search_module(self, module_name, module_version=None, update=False):
        """Look up an installed module by name.

        Args:
            module_name: name of the module.
            module_version: accepted for API compatibility; the lookup does
                not currently filter on it.
            update: rescan the module directory before the lookup.

        Returns:
            ``(module_dir, version)`` if the module is installed, else None.
        """
        return self.all_modules(update=update).get(module_name, None)

    @staticmethod
    def _module_tag(module_name, module_version):
        """Return ``name`` or ``name-version`` for user-facing messages."""
        if module_version:
            return '%s-%s' % (module_name, module_version)
        return module_name

    @staticmethod
    def _is_within_directory(directory, target):
        """Return True if ``target`` resolves to ``directory`` or below it."""
        directory = os.path.realpath(directory)
        target = os.path.realpath(target)
        # os.path.join(directory, "") appends exactly one separator, which
        # also behaves correctly when directory is the filesystem root.
        return target == directory or target.startswith(
            os.path.join(directory, ""))

    def _installed_result(self, module_name, module_version):
        """Build install_module's return value for an installed module."""
        module_info = self.modules_dict[module_name]
        tips = "Module %s already installed in %s" % (
            self._module_tag(module_name, module_version), module_info[0])
        return True, tips, module_info

    @staticmethod
    def _module_unavailable_tips(module_name, module_version):
        """Explain why the Hub server could not provide the module."""
        if hub.HubServer()._server_check() is False:
            return "Request Hub-Server unsuccessfully, please check your network."
        module_versions_info = hub.HubServer().search_module_info(module_name)
        if not module_versions_info:
            tips = "Can't find module %s" % module_name
            if module_version:
                tips += " with version %s" % module_version
            return tips

        # The module exists but no version matches the local
        # PaddlePaddle/PaddleHub; list the available versions as a table.
        if utils.is_windows():
            placeholders = [20, 8, 14, 14]
        else:
            placeholders = [30, 8, 16, 16]
        tp = TablePrinter(
            titles=["ResourceName", "Version", "PaddlePaddle", "PaddleHub"],
            placeholders=placeholders)
        module_versions_info.sort(key=cmp_to_key(utils.sort_version_key))
        for (resource_name, resource_version, paddle_version,
             hub_version) in module_versions_info:
            tp.add_line(
                contents=[
                    resource_name, resource_version,
                    utils.strflist_version(paddle_version),
                    utils.strflist_version(hub_version)
                ],
                colors=["yellow", None, None, None])
        return "The version of PaddlePaddle or PaddleHub " \
               "can not match module, please upgrade your " \
               "PaddlePaddle or PaddleHub according to the form " \
               "below." + tp.get_text()

    @classmethod
    def _extract_package(cls, module_package, dest):
        """Extract a ``.tar.gz`` module package into ``dest``.

        Returns:
            ``(True, None, module_dir)`` on success, where ``module_dir`` is
            ``dest`` joined with the archive's first entry name; otherwise
            ``(False, tips, None)``.
        """
        with tarfile.open(module_package, "r:gz") as tar:
            file_names = tar.getnames()
            if not file_names:
                return False, "Module package %s is empty" % module_package, None
            # Reject absolute or "../" member names that would be written
            # outside dest. Symlink targets are not inspected here.
            for file_name in file_names:
                if not cls._is_within_directory(
                        dest, os.path.join(dest, file_name)):
                    tips = "Unsafe path %s in module package %s" % (
                        file_name, module_package)
                    return False, tips, None
            for file_name in file_names:
                tar.extract(file_name, dest)
        return True, None, os.path.join(dest, file_names[0])

    def install_module(self,
                       module_name=None,
                       module_dir=None,
                       module_package=None,
                       module_version=None,
                       upgrade=False,
                       extra=None):
        """Install a module by name, from a local directory or a package.

        Args:
            module_name: module to download from the Hub server.
            module_dir: local module directory to copy into the module home.
            module_package: local ``.tar.gz`` module package to install.
            module_version: requested version when installing by name.
            upgrade: accepted for API compatibility; not used here.
            extra: extra info forwarded to ``HubServer.get_module_url``.

        Returns:
            ``(success, tips, info)``. On success ``info`` is
            ``(installed_module_dir, version)``. On failure it is None.
        """
        md5_value = installed_module_version = None
        from_user_dir = bool(module_dir)
        with tmp_dir() as _dir:
            if module_name:
                self.all_modules(update=True)
                module_info = self.modules_dict.get(module_name, None)
                if module_info and (not module_version
                                    or module_version == module_info[1]):
                    return self._installed_result(module_name, module_version)

                search_result = hub.HubServer().get_module_url(
                    module_name, version=module_version, extra=extra)
                name = search_result.get('name', None)
                url = search_result.get('url', None)
                md5_value = search_result.get('md5', None)
                installed_module_version = search_result.get('version', None)
                if not url or (module_version is not None
                               and installed_module_version != module_version
                               ) or (name != module_name):
                    tips = self._module_unavailable_tips(
                        module_name, module_version)
                    return False, tips, None

                result, tips, module_zip_file = default_downloader.download_file(
                    url=url,
                    save_path=_dir,
                    save_name=module_name,
                    replace=True,
                    print_progress=True)
                if not result:
                    return False, tips, None
                result, tips, module_dir = default_downloader.uncompress(
                    file=module_zip_file,
                    dirname=os.path.join(_dir, "tmp_module"),
                    delete_file=True,
                    print_progress=True)
                if not result:
                    return False, tips, None

            if module_package:
                ok, tips, module_dir = self._extract_package(
                    module_package, _dir)
                if not ok:
                    return False, tips, None
                package_module = hub.Module(directory=module_dir)
                module_name = package_module.name
                installed_module_version = package_module.version

            if from_user_dir:
                user_module = hub.Module(directory=module_dir)
                module_name = user_module.name
                module_version = user_module.version
                # Rescan so a module removed since the last scan is not
                # reported as installed.
                self.all_modules(update=True)
                module_info = self.modules_dict.get(module_name, None)
                if module_info and module_version == module_info[1]:
                    return self._installed_result(module_name, module_version)
                installed_module_version = module_version

            if not module_dir:
                tips = "Download %s failed" % self._module_tag(
                    module_name, module_version)
                return False, tips, module_dir

            if md5_value:
                with open(os.path.join(module_dir, "md5.txt"), "w") as fp:
                    fp.write(md5_value)

            save_path = os.path.join(self.local_modules_dir, module_name)
            # Compare resolved paths: a user directory that is already the
            # install location (perhaps spelled differently) must not be
            # rmtree'd before it is copied.
            if os.path.realpath(save_path) != os.path.realpath(module_dir):
                if os.path.exists(save_path):
                    shutil.rmtree(save_path)
                if from_user_dir:
                    shutil.copytree(module_dir, save_path)
                else:
                    shutil.move(module_dir, save_path)
            tips = "Successfully installed %s" % module_name
            if installed_module_version:
                tips += "-%s" % installed_module_version
            return True, tips, (save_path, installed_module_version)

    def uninstall_module(self, module_name, module_version=None):
        """Remove an installed module directory.

        Args:
            module_name: module to remove.
            module_version: if given, only remove when the installed version
                matches.

        Returns:
            ``(True, tips)``. A module that is not installed is reported in
            ``tips`` rather than treated as an error.
        """
        self.all_modules(update=True)
        if module_name not in self.modules_dict:
            return True, "%s is not installed" % module_name
        module_dir, installed_version = self.modules_dict[module_name]
        if module_version and module_version != installed_version:
            return True, "%s-%s is not installed" % (module_name,
                                                     module_version)
        shutil.rmtree(module_dir)
        # Keep the cache consistent with disk so later cached lookups do not
        # report the removed module as installed.
        del self.modules_dict[module_name]
        tips = "Successfully uninstalled %s" % module_name
        if module_version:
            tips += '-%s' % module_version
        return True, tips
W
wuzewu 已提交
228 229


W
wuzewu 已提交
230
# Shared manager rooted at the default module home. Constructing it runs
# LocalModuleManager.__init__, which creates that directory if it is missing.
default_module_manager = LocalModuleManager(module_home=None)