manager.py 6.2 KB
Newer Older
S
Steffy-zxf 已提交
1
#coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19 20 21 22

import os
import shutil

W
wuzewu 已提交
23
from paddlehub.common import utils
B
BinLong 已提交
24
from paddlehub.common import stats
W
wuzewu 已提交
25 26
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import MODULE_HOME
B
BinLong 已提交
27
from paddlehub.module import module_desc_pb2
W
wuzewu 已提交
28
import paddlehub as hub
W
wuzewu 已提交
29 30


W
wuzewu 已提交
31
class LocalModuleManager(object):
    """Track, install and remove PaddleHub modules kept in a local directory.

    Every installed module lives in its own sub-directory of
    ``local_modules_dir``. The sub-directory is named after the module and
    must contain a ``module_desc.pb`` description file. The result of the
    last directory scan is cached in ``modules_dict`` as
    ``{module_name: (module_dir, version)}``.
    """

    def __init__(self, module_home=None):
        """Create a manager rooted at ``module_home``.

        Args:
            module_home: directory holding installed modules; a falsy value
                selects ``MODULE_HOME``. The directory is created when it
                does not exist yet.
        """
        self.local_modules_dir = module_home if module_home else MODULE_HOME
        # Cache of the last scan: {module_name: (module_dir, version)}.
        self.modules_dict = {}
        if not os.path.exists(self.local_modules_dir):
            utils.mkdir(self.local_modules_dir)
        elif os.path.isfile(self.local_modules_dir):
            #TODO(wuzewu): give warning
            # all_modules() treats a non-directory home as "nothing installed".
            pass

    def check_module_valid(self, module_path):
        """Check whether ``module_path`` contains a usable module.

        Args:
            module_path: path of a candidate module directory.

        Returns:
            ``(True, info)``, where ``info`` is a dict with a ``'version'``
            key read from ``module_desc.pb``. Returns ``(False, None)`` when
            that file is missing or cannot be parsed.
        """
        #TODO(wuzewu): code
        try:
            desc_pb_path = os.path.join(module_path, 'module_desc.pb')
            if not os.path.isfile(desc_pb_path):
                # Without a description there is no version to report.
                # Accepting such a directory made all_modules() fail with a
                # KeyError on info['version'].
                return False, None
            desc = module_desc_pb2.ModuleDesc()
            with open(desc_pb_path, "rb") as fp:
                desc.ParseFromString(fp.read())
            version = desc.attr.map.data["module_info"].map.data[
                "version"].s
        except Exception:
            # A corrupt or incompatible description only disqualifies this
            # directory; it must not abort the whole scan.
            return False, None
        return True, {'version': version}

    def all_modules(self, update=False):
        """Return the installed modules as ``{name: (module_dir, version)}``.

        Args:
            update: rescan ``local_modules_dir`` even if a cached, non-empty
                result exists.

        Returns:
            The (possibly refreshed) ``modules_dict``.
        """
        if not update and self.modules_dict:
            return self.modules_dict
        self.modules_dict = {}
        # The home may be a plain file (tolerated by __init__); nothing to list.
        if not os.path.isdir(self.local_modules_dir):
            return self.modules_dict
        for sub_dir_name in os.listdir(self.local_modules_dir):
            sub_dir_path = os.path.join(self.local_modules_dir, sub_dir_name)
            if not os.path.isdir(sub_dir_path):
                continue
            #TODO(wuzewu): get module name
            valid, info = self.check_module_valid(sub_dir_path)
            if valid:
                # The directory name doubles as the module name.
                self.modules_dict[sub_dir_name] = (sub_dir_path,
                                                   info['version'])
        return self.modules_dict

    def search_module(self, module_name, module_version=None, update=False):
        """Look up an installed module.

        Args:
            module_name: name of the module.
            module_version: when given, only an installed module of exactly
                this version matches.
            update: rescan the module directory before the lookup.

        Returns:
            ``(module_dir, version)``, or None when no matching module is
            installed.
        """
        self.all_modules(update=update)
        module_info = self.modules_dict.get(module_name, None)
        if module_info and module_version and module_info[1] != module_version:
            # Same rule as install_module/uninstall_module: a different
            # installed version is not a match.
            return None
        return module_info

    @staticmethod
    def _module_tag(module_name, module_version=None):
        """Return ``name`` or ``name-version`` for messages and statistics."""
        if module_version:
            return '%s-%s' % (module_name, module_version)
        return module_name

    def install_module(self, module_name, module_version=None, upgrade=False):
        """Install a module from the hub server unless it is already present.

        Args:
            module_name: name of the module to install.
            module_version: exact version wanted. When None, an installed
                copy of any version is kept; otherwise the server's current
                version is installed.
            upgrade: currently unused; kept for interface compatibility.

        Returns:
            ``(success, tips, module_info)``, where ``module_info`` is
            ``(module_dir, version)`` on success and None on failure.
        """
        self.all_modules(update=True)
        module_tag = self._module_tag(module_name, module_version)
        module_info = self.modules_dict.get(module_name, None)
        if module_info and (not module_version or
                            module_version == module_info[1]):
            stats.hub_stat(['installed', module_tag])
            tips = "Module %s already installed in %s" % (module_tag,
                                                          module_info[0])
            return True, tips, module_info

        search_result = hub.default_hub_server.get_module_url(
            module_name, version=module_version) or {}
        name = search_result.get('name', None)
        url = search_result.get('url', None)
        # NOTE(review): the server result also carries an 'md5' value that
        # is not verified against the downloaded archive yet.
        installed_module_version = search_result.get('version', None)
        #TODO(wuzewu): add compatibility check
        found = (url and name == module_name and
                 (module_version is None or
                  installed_module_version == module_version))
        if not found:
            tips = "Can't find module %s" % module_name
            if module_version:
                tips += " with version %s" % module_version
            stats.hub_stat(['install fail', module_tag])
            return False, tips, None

        failed_tips = "Download %s failed" % module_tag
        result, tips, module_zip_file = default_downloader.download_file(
            url=url,
            save_path=hub.CACHE_HOME,
            save_name=module_name,
            replace=True,
            print_progress=True)
        if not result or not module_zip_file:
            return False, failed_tips, None
        # Extract into the directory this manager scans, so all_modules()
        # finds the module when a custom module_home is used.
        result, tips, module_dir = default_downloader.uncompress(
            file=module_zip_file,
            dirname=self.local_modules_dir,
            delete_file=True,
            print_progress=True)
        if not result or not module_dir:
            return False, failed_tips, None

        save_path = os.path.join(self.local_modules_dir, module_name)
        # If the extracted directory already is save_path, removing
        # save_path would delete the files that were just extracted.
        if os.path.abspath(module_dir) != os.path.abspath(save_path):
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
            shutil.move(module_dir, save_path)
        module_dir = save_path
        tips = "Successfully installed %s" % module_name
        stats.hub_stat(['install', module_name, url])
        if installed_module_version:
            tips += "-%s" % installed_module_version
        # Keep the cache in sync so search_module() sees the new module.
        self.modules_dict[module_name] = (module_dir, installed_module_version)
        return True, tips, (module_dir, installed_module_version)

    def uninstall_module(self, module_name, module_version=None):
        """Delete an installed module's directory.

        Args:
            module_name: name of the module to remove.
            module_version: when given, only remove the module if exactly
                this version is installed.

        Returns:
            ``(True, tips)``. Asking to uninstall a module (or version) that
            is not installed is reported in ``tips`` rather than as a failure.
        """
        self.all_modules(update=True)
        if module_name not in self.modules_dict:
            tips = "%s is not installed" % module_name
            return True, tips
        module_dir, installed_version = self.modules_dict[module_name]
        if module_version and module_version != installed_version:
            tips = "%s-%s is not installed" % (module_name, module_version)
            return True, tips
        stats.hub_stat(['uninstall', module_name])
        tips = "Successfully uninstalled %s" % module_name
        if module_version:
            tips += '-%s' % module_version
        shutil.rmtree(module_dir)
        # The directory is gone; drop its stale cache entry.
        self.modules_dict.pop(module_name, None)
        return True, tips
W
wuzewu 已提交
148 149


W
wuzewu 已提交
150
# Module-level manager for the default module home (module_home=None selects
# MODULE_HOME); constructing it creates that directory if it is missing.
default_module_manager = LocalModuleManager(module_home=None)