manager.py 6.4 KB
Newer Older
S
Steffy-zxf 已提交
1
#coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19 20 21 22

import os
import shutil

W
wuzewu 已提交
23
from paddlehub.common import utils
24
from paddlehub.common import srv_utils
W
wuzewu 已提交
25 26
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import MODULE_HOME
B
BinLong 已提交
27
from paddlehub.module import module_desc_pb2
W
wuzewu 已提交
28
import paddlehub as hub
K
kinghuin 已提交
29
from paddlehub.common.logger import logger
W
wuzewu 已提交
30 31


W
wuzewu 已提交
32
class LocalModuleManager(object):
    """Manage the modules stored under a local "module home" directory.

    Each installed module is a sub-directory of ``local_modules_dir`` that
    contains a ``module_desc.pb`` file. ``modules_dict`` caches the result of
    the last directory scan as ``{module_name: (module_dir, version)}``.
    """

    def __init__(self, module_home=None):
        """Create a manager rooted at ``module_home``.

        Args:
            module_home: Directory holding the installed modules. When it is
                falsy, the package-wide ``MODULE_HOME`` is used. The directory
                is created if it does not exist.

        Raises:
            ValueError: If the path exists but is a regular file.
        """
        self.local_modules_dir = module_home if module_home else MODULE_HOME
        self.modules_dict = {}
        if not os.path.exists(self.local_modules_dir):
            utils.mkdir(self.local_modules_dir)
        elif os.path.isfile(self.local_modules_dir):
            raise ValueError("Module home should be a folder, not a file")

    def check_module_valid(self, module_path):
        """Check whether ``module_path`` holds a readable module description.

        Args:
            module_path: Directory expected to contain ``module_desc.pb``.

        Returns:
            ``(True, info)`` where ``info['version']`` is the version string
            read from the description file, or ``(False, None)`` when the file
            is missing or cannot be read/parsed.
        """
        try:
            desc_pb_path = os.path.join(module_path, 'module_desc.pb')
            if not os.path.isfile(desc_pb_path):
                logger.warning(
                    "%s does not exist, the module will be reinstalled" %
                    desc_pb_path)
                return False, None

            desc = module_desc_pb2.ModuleDesc()
            with open(desc_pb_path, "rb") as fp:
                desc.ParseFromString(fp.read())
            info = {
                'version':
                desc.attr.map.data["module_info"].map.data["version"].s
            }
            return True, info
        except Exception as e:
            # Validation stays best-effort: an unreadable description only
            # marks the module as invalid. Catching Exception (instead of a
            # bare except) lets KeyboardInterrupt/SystemExit propagate.
            logger.warning("Failed to read the module description in %s: %s" %
                           (module_path, e))
        return False, None

    def all_modules(self, update=False):
        """Return the mapping of installed modules.

        Args:
            update: When False and a non-empty cache exists, the cache is
                returned as is. Otherwise the module home is rescanned.

        Returns:
            dict mapping module name (the sub-directory name) to a
            ``(module_dir, version)`` tuple.
        """
        if not update and self.modules_dict:
            return self.modules_dict

        self.modules_dict = {}
        for sub_dir_name in os.listdir(self.local_modules_dir):
            sub_dir_path = os.path.join(self.local_modules_dir, sub_dir_name)
            if not os.path.isdir(sub_dir_path):
                continue
            valid, info = self.check_module_valid(sub_dir_path)
            if valid:
                self.modules_dict[sub_dir_name] = (sub_dir_path,
                                                   info['version'])
        return self.modules_dict

    def search_module(self, module_name, module_version=None, update=False):
        """Look up an installed module by name.

        Args:
            module_name: Name of the module to find.
            module_version: Accepted for interface compatibility; it is not
                used to filter the result.
            update: Forwarded to ``all_modules`` to force a rescan.

        Returns:
            The ``(module_dir, version)`` tuple of the module, or None when no
            module with that name is installed.
        """
        self.all_modules(update=update)
        return self.modules_dict.get(module_name, None)

    def install_module(self,
                       module_name,
                       module_version=None,
                       upgrade=False,
                       extra=None):
        """Install a module into this manager's module home.

        Args:
            module_name: Name of the module to install.
            module_version: Required version, or None for whatever version the
                hub server returns.
            upgrade: Accepted for interface compatibility; currently unused.
            extra: Passed through to the hub server query.

        Returns:
            ``(success, tips, info)``. On success ``info`` is the
            ``(module_dir, version)`` tuple of the installed module; ``tips``
            is a human-readable message in every case.
        """
        self.all_modules(update=True)
        module_tag = module_name if not module_version else '%s-%s' % (
            module_name, module_version)

        module_info = self.modules_dict.get(module_name, None)
        if module_info:
            installed_dir, installed_version = module_info
            if not module_version or module_version == installed_version:
                tips = "Module %s already installed in %s" % (module_tag,
                                                              installed_dir)
                return True, tips, module_info

        search_result = hub.default_hub_server.get_module_url(
            module_name, version=module_version, extra=extra)
        name = search_result.get('name', None)
        url = search_result.get('url', None)
        md5_value = search_result.get('md5', None)
        installed_module_version = search_result.get('version', None)
        version_mismatch = (module_version is not None
                            and installed_module_version != module_version)
        if not url or version_mismatch or name != module_name:
            tips = "Can't find module %s" % module_name
            if module_version:
                tips += " with version %s" % module_version
            return False, tips, None

        result, tips, module_zip_file = default_downloader.download_file(
            url=url,
            save_path=hub.CACHE_HOME,
            save_name=module_name,
            replace=True,
            print_progress=True)
        if not result or not module_zip_file:
            # Without an archive there is nothing to unpack.
            return False, "Download %s failed" % module_tag, None

        # Unpack into this manager's own home (not the global MODULE_HOME) so
        # that a manager created with a custom module_home finds the module on
        # the next scan.
        result, tips, module_dir = default_downloader.uncompress(
            file=module_zip_file,
            dirname=self.local_modules_dir,
            delete_file=True,
            print_progress=True)
        if not module_dir:
            return False, "Download %s failed" % module_tag, module_dir

        # os.path.join leaves an absolute module_dir untouched and resolves a
        # relative one against the module home.
        extracted_dir = os.path.join(self.local_modules_dir, module_dir)
        with open(os.path.join(extracted_dir, "md5.txt"), "w") as fp:
            # The server may not report an md5; write an empty file rather
            # than failing on None.
            fp.write(md5_value or "")

        save_path = os.path.join(self.local_modules_dir, module_name)
        if os.path.abspath(extracted_dir) != os.path.abspath(save_path):
            # Replace any previous installation. Skipped when the archive was
            # already extracted to the target path, where removing save_path
            # would delete the freshly installed module.
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
            shutil.move(extracted_dir, save_path)

        tips = "Successfully installed %s" % module_name
        if installed_module_version:
            tips += "-%s" % installed_module_version
        installed_info = (save_path, installed_module_version)
        # Keep the cache in step with what is now on disk.
        self.modules_dict[module_name] = installed_info
        return True, tips, installed_info

    def uninstall_module(self, module_name, module_version=None):
        """Remove an installed module from the module home.

        Args:
            module_name: Name of the module to remove.
            module_version: When given, the module is only removed if the
                installed version matches.

        Returns:
            ``(True, tips)``; ``tips`` says whether the module was removed or
            was not installed in the first place.
        """
        self.all_modules(update=True)
        if module_name not in self.modules_dict:
            return True, "%s is not installed" % module_name

        module_dir, installed_version = self.modules_dict[module_name]
        if module_version and module_version != installed_version:
            tips = "%s-%s is not installed" % (module_name, module_version)
            return True, tips

        tips = "Successfully uninstalled %s" % module_name
        if module_version:
            tips += '-%s' % module_version
        shutil.rmtree(module_dir)
        # Drop the cache entry so later lookups do not report the module.
        self.modules_dict.pop(module_name, None)
        return True, tips
W
wuzewu 已提交
153 154


W
wuzewu 已提交
155
# Shared module-level manager. It is built without a module_home argument, so
# it operates on the package-wide MODULE_HOME directory.
default_module_manager = LocalModuleManager()