manager.py 10.1 KB
Newer Older
S
Steffy-zxf 已提交
1
#coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19 20 21

import os
import shutil
22

23
from functools import cmp_to_key
W
wuzewu 已提交
24
import tarfile
W
wuzewu 已提交
25

W
wuzewu 已提交
26
from paddlehub.common import utils
27
from paddlehub.common import srv_utils
W
wuzewu 已提交
28
from paddlehub.common.downloader import default_downloader
S
shenyuhan 已提交
29
from paddlehub.common.hub_server import default_hub_server
W
wuzewu 已提交
30
from paddlehub.common.dir import MODULE_HOME
31
from paddlehub.common.cml_utils import TablePrinter
B
BinLong 已提交
32
from paddlehub.module import module_desc_pb2
W
wuzewu 已提交
33
import paddlehub as hub
K
kinghuin 已提交
34
from paddlehub.common.logger import logger
W
wuzewu 已提交
35 36


W
wuzewu 已提交
37
class LocalModuleManager(object):
    """Manage the PaddleHub modules installed under a local directory.

    Every sub-directory of ``local_modules_dir`` that contains a parseable
    ``module_desc.pb`` is treated as an installed module whose name is the
    sub-directory name.
    """

    def __init__(self, module_home=None):
        """Create a manager rooted at ``module_home`` (``MODULE_HOME`` if unset).

        The directory is created when it does not exist yet.

        Raises:
            ValueError: if ``module_home`` exists but is a regular file.
        """
        self.local_modules_dir = module_home if module_home else MODULE_HOME
        # Cache of {module_name: (module_dir, version)} filled by all_modules().
        self.modules_dict = {}
        if not os.path.exists(self.local_modules_dir):
            utils.mkdir(self.local_modules_dir)
        elif os.path.isfile(self.local_modules_dir):
            raise ValueError("Module home should be a folder, not a file")

    def check_module_valid(self, module_path):
        """Check whether ``module_path`` holds a readable module description.

        Returns:
            ``(True, {'version': <version>})`` when ``module_desc.pb`` exists
            in ``module_path`` and can be parsed, otherwise ``(False, None)``.
        """
        desc_pb_path = os.path.join(module_path, 'module_desc.pb')
        if not os.path.isfile(desc_pb_path):
            logger.warning(
                "%s does not exist, the module will be reinstalled" %
                desc_pb_path)
            return False, None
        try:
            desc = module_desc_pb2.ModuleDesc()
            with open(desc_pb_path, "rb") as fp:
                desc.ParseFromString(fp.read())
            version = desc.attr.map.data["module_info"].map.data["version"].s
        except Exception as err:
            # A corrupt description only disqualifies this directory; report
            # it instead of silently hiding the reason.
            logger.warning("Failed to parse %s: %s" % (desc_pb_path, err))
            return False, None
        return True, {'version': version}

    def all_modules(self, update=False):
        """Return ``{module_name: (module_dir, version)}`` of installed modules.

        Args:
            update: rescan ``local_modules_dir`` even if a cached result
                exists. An empty cache is always rescanned.
        """
        if not update and self.modules_dict:
            return self.modules_dict
        self.modules_dict = {}
        for sub_dir_name in os.listdir(self.local_modules_dir):
            sub_dir_path = os.path.join(self.local_modules_dir, sub_dir_name)
            if not os.path.isdir(sub_dir_path):
                continue
            valid, info = self.check_module_valid(sub_dir_path)
            if valid:
                self.modules_dict[sub_dir_name] = (sub_dir_path,
                                                   info['version'])
        return self.modules_dict

    def search_module(self, module_name, module_version=None, update=False):
        """Return ``(module_dir, version)`` of an installed module, or None.

        ``module_version`` is accepted but not used for filtering: the entry
        for ``module_name`` is returned whatever its version.
        """
        self.all_modules(update=update)
        return self.modules_dict.get(module_name, None)

    @staticmethod
    def _installed_tips(module_name, module_version, module_dir):
        """Format the message reported when a module is already installed."""
        module_tag = module_name if not module_version else '%s-%s' % (
            module_name, module_version)
        return "Module %s already installed in %s" % (module_tag, module_dir)

    @staticmethod
    def _version_mismatch_tips(module_name, module_version):
        """Build the message for a module the hub server could not provide.

        If the server knows other versions of ``module_name``, they are shown
        in a table with the PaddlePaddle/PaddleHub versions each requires;
        otherwise the module (version) is reported as not found.
        """
        module_versions_info = default_hub_server.search_module_info(
            module_name)
        if not module_versions_info:
            tips = "Can't find module %s" % module_name
            if module_version:
                tips += " with version %s" % module_version
            return tips

        if utils.is_windows():
            placeholders = [20, 8, 14, 14]
        else:
            placeholders = [30, 8, 16, 16]
        tp = TablePrinter(
            titles=["ResourceName", "Version", "PaddlePaddle", "PaddleHub"],
            placeholders=placeholders)
        module_versions_info.sort(key=cmp_to_key(utils.sort_version_key))
        for resource_name, resource_version, paddle_version, \
                hub_version in module_versions_info:
            tp.add_line(
                contents=[
                    resource_name, resource_version,
                    utils.strflist_version(paddle_version),
                    utils.strflist_version(hub_version)
                ],
                colors=["yellow", None, None, None])
        return "The version of PaddlePaddle or PaddleHub " \
               "can not match module, please upgrade your " \
               "PaddlePaddle or PaddleHub according to the form " \
               "below." + tp.get_text()

    @staticmethod
    def _extract_package(module_package, extract_root):
        """Extract the ``.tar.gz`` ``module_package`` under ``extract_root``.

        The module root is the first path component of the archive's first
        entry; a stale copy of it under ``extract_root`` is removed first.

        Returns:
            The path of the extracted module root directory.

        Raises:
            ValueError: if the archive is empty, has no usable top-level
                directory, or has an entry or link target that would land
                outside ``extract_root`` (path traversal).
        """
        root = os.path.realpath(extract_root)

        def _inside(path):
            # True when ``path`` (relative to the root) stays under the root.
            target = os.path.realpath(os.path.join(root, path))
            return target == root or target.startswith(root + os.sep)

        with tarfile.open(module_package, "r:gz") as tar:
            members = tar.getmembers()
            if not members:
                raise ValueError("Module package %s is empty" % module_package)
            for member in members:
                unsafe = not _inside(member.name)
                if member.issym():
                    # Symlink targets are relative to the link's directory.
                    unsafe = unsafe or not _inside(
                        os.path.join(
                            os.path.dirname(member.name), member.linkname))
                elif member.islnk():
                    # Hard-link targets are relative to the archive root.
                    unsafe = unsafe or not _inside(member.linkname)
                if unsafe:
                    raise ValueError("Unsafe path %s in module package %s" %
                                     (member.name, module_package))

            top_dir = os.path.normpath(members[0].name).split(os.sep)[0]
            if top_dir in ("", ".", ".."):
                raise ValueError("Module package %s has no top-level directory"
                                 % module_package)
            module_dir = os.path.join(extract_root, top_dir)
            # remove cache
            if os.path.exists(module_dir):
                shutil.rmtree(module_dir)

            extract_kwargs = {}
            if hasattr(tarfile, "tar_filter"):
                # Interpreters with extraction filters: keep the stdlib guard
                # on as well (and avoid the 3.12+ DeprecationWarning).
                extract_kwargs["filter"] = "tar"
            tar.extractall(extract_root, members=members, **extract_kwargs)
        return module_dir

    def install_module(self,
                       module_name=None,
                       module_dir=None,
                       module_package=None,
                       module_version=None,
                       upgrade=False,
                       extra=None):
        """Install a module by name, from a directory, or from a package.

        Args:
            module_name: download this module from the hub server. When
                given, the downloaded copy replaces ``module_dir``.
            module_dir: an existing module directory; it is copied (not
                moved) into the module home.
            module_package: a ``.tar.gz`` module package to extract and
                install.
            module_version: version required when installing by name.
            upgrade: accepted for API compatibility; not used here.
            extra: extra request data forwarded to the hub server.

        Returns:
            ``(success, tips, info)``. On success ``info`` is
            ``(module_dir, version)``. ``version`` is None for directory or
            package installs that are not already installed. On failure
            ``info`` is None.
        """
        md5_value = installed_module_version = None
        # A user-supplied directory is copied so the user's files survive;
        # directories this method creates itself are moved instead.
        from_user_dir = True if module_dir else False
        if module_name:
            module_info = self.all_modules(update=True).get(module_name, None)
            if module_info and (not module_version
                                or module_version == module_info[1]):
                tips = self._installed_tips(module_name, module_version,
                                            module_info[0])
                return True, tips, module_info

            search_result = hub.default_hub_server.get_module_url(
                module_name, version=module_version, extra=extra) or {}
            name = search_result.get('name', None)
            url = search_result.get('url', None)
            md5_value = search_result.get('md5', None)
            installed_module_version = search_result.get('version', None)
            if not url or (module_version is not None
                           and installed_module_version != module_version) or (
                               name != module_name):
                if default_hub_server._server_check() is False:
                    tips = "Request Hub-Server unsuccessfully, please check your network."
                    return False, tips, None
                return False, self._version_mismatch_tips(
                    module_name, module_version), None

            result, tips, module_zip_file = default_downloader.download_file(
                url=url,
                save_path=hub.CACHE_HOME,
                save_name=module_name,
                replace=True,
                print_progress=True)
            if not result:
                return False, tips, None
            # Extract into the cache, not the module home. This keeps the
            # extracted tree from colliding with (or being mistaken for) the
            # installed copy at the destination.
            result, tips, module_dir = default_downloader.uncompress(
                file=module_zip_file,
                dirname=hub.CACHE_HOME,
                delete_file=True,
                print_progress=True)
            if not result:
                return False, tips, None
            from_user_dir = False

        if module_package:
            try:
                module_dir = self._extract_package(module_package,
                                                   hub.CACHE_HOME)
            except ValueError as err:
                return False, str(err), None
            from_user_dir = False

        if not module_dir:
            tips = "Download %s-%s failed" % (module_name, module_version)
            return False, tips, module_dir

        if not module_name:
            # Local directory/package install: the name comes from the module
            # itself. Downloads were already checked against installed modules.
            module_name = hub.Module(directory=module_dir).name
            module_info = self.all_modules(update=True).get(module_name, None)
            if module_info:
                tips = self._installed_tips(module_name, module_version,
                                            module_info[0])
                return True, tips, module_info

        if md5_value:
            with open(os.path.join(module_dir, "md5.txt"), "w") as fp:
                fp.write(md5_value)

        save_path = os.path.join(self.local_modules_dir, module_name)
        # Never delete or move a directory onto itself.
        if os.path.realpath(module_dir) != os.path.realpath(save_path):
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
            if from_user_dir:
                shutil.copytree(module_dir, save_path)
            else:
                shutil.move(module_dir, save_path)
        # Invalidate the cache so later lookups see the new module.
        self.modules_dict = {}

        tips = "Successfully installed %s" % module_name
        if installed_module_version:
            tips += "-%s" % installed_module_version
        return True, tips, (save_path, installed_module_version)

    def uninstall_module(self, module_name, module_version=None):
        """Remove an installed module directory.

        Returns:
            ``(True, tips)``. This is returned even when nothing matched
            ``module_name`` (and ``module_version``, if given); ``tips``
            says what happened.
        """
        module_info = self.all_modules(update=True).get(module_name, None)
        if module_info is None:
            tips = "%s is not installed" % module_name
            return True, tips
        if module_version and module_version != module_info[1]:
            tips = "%s-%s is not installed" % (module_name, module_version)
            return True, tips
        shutil.rmtree(module_info[0])
        # Drop the cached entry so the removed module is not reported later.
        self.modules_dict.pop(module_name, None)
        tips = "Successfully uninstalled %s" % module_name
        if module_version:
            tips += '-%s' % module_version
        return True, tips
W
wuzewu 已提交
231 232


W
wuzewu 已提交
233
# Module-level manager instance; with no module_home it is rooted at MODULE_HOME.
default_module_manager = LocalModuleManager(module_home=None)