manager.py 4.0 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
18 19 20 21

import os
import shutil

W
wuzewu 已提交
22 23 24 25
from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import MODULE_HOME
import paddlehub as hub
W
wuzewu 已提交
26 27 28


class LocalModuleManager(object):
    """Track, install and remove modules stored under a local directory.

    Every sub-directory of ``local_modules_dir`` is treated as one installed
    module whose name is the directory name. The result of the last scan is
    cached in ``modules_dict`` (module name -> module directory).
    """

    def __init__(self, module_home=None):
        """Create a manager rooted at ``module_home``.

        Args:
            module_home: Directory that holds the installed modules. Falls
                back to ``MODULE_HOME`` when empty or ``None``. If the path
                does not exist, ``utils.mkdir`` is called on it.
        """
        self.local_modules_dir = module_home if module_home else MODULE_HOME
        self.modules_dict = {}
        if not os.path.exists(self.local_modules_dir):
            utils.mkdir(self.local_modules_dir)
        elif os.path.isfile(self.local_modules_dir):
            #TODO(wuzewu): give warning
            pass

    def check_module_valid(self, module_path):
        """Return whether ``module_path`` holds a valid module.

        Validation is not implemented yet, so every path is accepted.
        """
        #TODO(wuzewu): code
        return True

    def all_modules(self, update=False):
        """Return the mapping of installed module names to their directories.

        Args:
            update: When true, rescan ``local_modules_dir`` even if a cached
                result exists. An empty cache always triggers a rescan.

        Returns:
            The ``modules_dict`` cache itself (not a copy).
        """
        if not update and self.modules_dict:
            return self.modules_dict
        self.modules_dict = {}
        # The module home may be a regular file (tolerated by __init__) or
        # may have been removed since construction; it then holds no modules
        # and listing it would raise.
        if not os.path.isdir(self.local_modules_dir):
            return self.modules_dict
        for sub_dir_name in os.listdir(self.local_modules_dir):
            sub_dir_path = os.path.join(self.local_modules_dir, sub_dir_name)
            if os.path.isdir(sub_dir_path) and self.check_module_valid(
                    sub_dir_path):
                #TODO(wuzewu): get module name
                module_name = sub_dir_name
                self.modules_dict[module_name] = sub_dir_path

        return self.modules_dict

    def search_module(self, module_name, update=False):
        """Return the directory of ``module_name``, or ``None`` if absent.

        Args:
            module_name: Name of the module to look up.
            update: Forwarded to ``all_modules`` to force a rescan.
        """
        self.all_modules(update=update)
        return self.modules_dict.get(module_name, None)

    def install_module(self, module_name, module_version=None, upgrade=False):
        """Download and unpack ``module_name`` into the local module home.

        Args:
            module_name: Name of the module to install.
            module_version: Optional version passed to the hub server when
                resolving the download url.
            upgrade: Currently unused; kept for interface compatibility.

        Returns:
            A ``(success, tips, module_dir)`` tuple. ``module_dir`` is the
            install directory on success (also when the module was already
            installed) and ``None`` on failure.
        """
        self.all_modules(update=True)
        if module_name in self.modules_dict:
            module_dir = self.modules_dict[module_name]
            tips = "Module %s already installed in %s" % (module_name,
                                                          module_dir)
            return True, tips, module_dir

        url = hub.default_hub_server.get_module_url(
            module_name, version=module_version)
        #TODO(wuzewu): add compatibility check
        if not url:
            tips = "Can't find module %s" % module_name
            if module_version:
                tips += " with version %s" % module_version
            return False, tips, None

        # Only mention the version when one was requested, so the message
        # never reads "<name>-None".
        failed_tips = "Download %s" % module_name
        if module_version:
            failed_tips += "-%s" % module_version
        failed_tips += " failed"

        # NOTE(review): download_file/uncompress are unpacked as
        # (result, tips, path); a falsy result or path is treated as failure
        # here -- confirm against the downloader implementation.
        result, tips, module_zip_file = default_downloader.download_file(
            url=url,
            save_path=hub.CACHE_HOME,
            save_name=module_name,
            replace=True)
        if not result or not module_zip_file:
            return False, failed_tips, None

        # Unpack into this manager's own home so that the module is found by
        # all_modules() even when a custom module_home was given.
        result, tips, module_dir = default_downloader.uncompress(
            file=module_zip_file,
            dirname=self.local_modules_dir,
            delete_file=True)
        if not result or not module_dir:
            return False, failed_tips, None

        # The archive's top-level directory may not be named after the
        # module; move it so the directory name matches the module name.
        save_path = os.path.join(self.local_modules_dir, module_name)
        shutil.move(module_dir, save_path)
        self.modules_dict[module_name] = save_path

        tips = "Successfully installed %s" % module_name
        if module_version:
            tips += "-%s" % module_version
        return True, tips, save_path

    def uninstall_module(self, module_name):
        """Delete the directory of ``module_name`` if it is installed.

        Returns:
            A ``(success, tips)`` tuple. ``success`` is ``True`` both when
            the module was removed and when it was not installed.
        """
        self.all_modules(update=True)
        if module_name not in self.modules_dict:
            tips = "%s is not installed" % module_name
            return True, tips
        shutil.rmtree(self.modules_dict[module_name])
        # Drop the cache entry so cached lookups do not return a deleted path.
        del self.modules_dict[module_name]
        tips = "Successfully uninstalled %s" % module_name
        return True, tips
W
wuzewu 已提交
105 106


W
wuzewu 已提交
107
# Shared manager instance built at import time. No module_home is passed, so
# it is rooted at MODULE_HOME.
default_module_manager = LocalModuleManager()