提交 ed8afae2 编写于 作者: W wuzewu

cache module in temp dir

上级 48d0b27d
......@@ -23,13 +23,14 @@ import shutil
from functools import cmp_to_key
import tarfile
import paddlehub as hub
from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import MODULE_HOME
from paddlehub.common.cml_utils import TablePrinter
from paddlehub.module import module_desc_pb2
import paddlehub as hub
from paddlehub.common.logger import logger
from paddlehub.common import tmp_dir
from paddlehub.module import module_desc_pb2
class LocalModuleManager(object):
......@@ -87,12 +88,97 @@ class LocalModuleManager(object):
extra=None):
md5_value = installed_module_version = None
from_user_dir = True if module_dir else False
if module_name:
self.all_modules(update=True)
module_info = self.modules_dict.get(module_name, None)
if module_info:
if not module_version or module_version == self.modules_dict[
module_name][1]:
with tmp_dir() as _dir:
if module_name:
self.all_modules(update=True)
module_info = self.modules_dict.get(module_name, None)
if module_info:
if not module_version or module_version == self.modules_dict[
module_name][1]:
module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
tips = "Module %s already installed in %s" % (
module_tag, module_dir)
return True, tips, self.modules_dict[module_name]
search_result = hub.HubServer().get_module_url(
module_name, version=module_version, extra=extra)
name = search_result.get('name', None)
url = search_result.get('url', None)
md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None)
if not url or (module_version is not None
and installed_module_version != module_version
) or (name != module_name):
if hub.HubServer()._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
return False, tips, None
module_versions_info = hub.HubServer().search_module_info(
module_name)
if module_versions_info is not None and len(
module_versions_info) > 0:
if utils.is_windows():
placeholders = [20, 8, 14, 14]
else:
placeholders = [30, 8, 16, 16]
tp = TablePrinter(
titles=[
"ResourceName", "Version", "PaddlePaddle",
"PaddleHub"
],
placeholders=placeholders)
module_versions_info.sort(
key=cmp_to_key(utils.sort_version_key))
for resource_name, resource_version, paddle_version, \
hub_version in module_versions_info:
colors = ["yellow", None, None, None]
tp.add_line(
contents=[
resource_name, resource_version,
utils.strflist_version(paddle_version),
utils.strflist_version(hub_version)
],
colors=colors)
tips = "The version of PaddlePaddle or PaddleHub " \
"can not match module, please upgrade your " \
"PaddlePaddle or PaddleHub according to the form " \
"below." + tp.get_text()
else:
tips = "Can't find module %s" % module_name
if module_version:
tips += " with version %s" % module_version
return False, tips, None
result, tips, module_zip_file = default_downloader.download_file(
url=url,
save_path=_dir,
save_name=module_name,
replace=True,
print_progress=True)
result, tips, module_dir = default_downloader.uncompress(
file=module_zip_file,
dirname=MODULE_HOME,
delete_file=True,
print_progress=True)
if module_package:
with tarfile.open(module_package, "r:gz") as tar:
file_names = tar.getnames()
size = len(file_names) - 1
module_dir = os.path.split(file_names[0])[0]
module_dir = os.path.join(_dir, module_dir)
for index, file_name in enumerate(file_names):
tar.extract(file_name, _dir)
if module_dir:
if not module_name:
module_name = hub.Module(directory=module_dir).name
self.all_modules(update=False)
module_info = self.modules_dict.get(module_name, None)
if module_info:
module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
......@@ -100,114 +186,27 @@ class LocalModuleManager(object):
module_dir)
return True, tips, self.modules_dict[module_name]
search_result = hub.HubServer().get_module_url(
module_name, version=module_version, extra=extra)
name = search_result.get('name', None)
url = search_result.get('url', None)
md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None)
if not url or (module_version is not None
and installed_module_version != module_version) or (
name != module_name):
if hub.HubServer()._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
return False, tips, None
module_versions_info = hub.HubServer().search_module_info(
module_name)
if module_versions_info is not None and len(
module_versions_info) > 0:
if utils.is_windows():
placeholders = [20, 8, 14, 14]
else:
placeholders = [30, 8, 16, 16]
tp = TablePrinter(
titles=[
"ResourceName", "Version", "PaddlePaddle",
"PaddleHub"
],
placeholders=placeholders)
module_versions_info.sort(
key=cmp_to_key(utils.sort_version_key))
for resource_name, resource_version, paddle_version, \
hub_version in module_versions_info:
colors = ["yellow", None, None, None]
tp.add_line(
contents=[
resource_name, resource_version,
utils.strflist_version(paddle_version),
utils.strflist_version(hub_version)
],
colors=colors)
tips = "The version of PaddlePaddle or PaddleHub " \
"can not match module, please upgrade your " \
"PaddlePaddle or PaddleHub according to the form " \
"below." + tp.get_text()
if module_dir:
if md5_value:
with open(
os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp:
fp.write(md5_value)
save_path = os.path.join(MODULE_HOME, module_name)
if os.path.exists(save_path):
shutil.move(save_path)
if from_user_dir:
shutil.copytree(module_dir, save_path)
else:
tips = "Can't find module %s" % module_name
if module_version:
tips += " with version %s" % module_version
return False, tips, None
result, tips, module_zip_file = default_downloader.download_file(
url=url,
save_path=hub.CACHE_HOME,
save_name=module_name,
replace=True,
print_progress=True)
result, tips, module_dir = default_downloader.uncompress(
file=module_zip_file,
dirname=MODULE_HOME,
delete_file=True,
print_progress=True)
if module_package:
with tarfile.open(module_package, "r:gz") as tar:
file_names = tar.getnames()
size = len(file_names) - 1
module_dir = os.path.split(file_names[0])[1]
module_dir = os.path.join(hub.CACHE_HOME, module_dir)
# remove cache
if os.path.exists(module_dir):
shutil.rmtree(module_dir)
for index, file_name in enumerate(file_names):
tar.extract(file_name, hub.CACHE_HOME)
if module_dir:
if not module_name:
module_name = hub.Module(directory=module_dir).name
self.all_modules(update=False)
module_info = self.modules_dict.get(module_name, None)
if module_info:
module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
tips = "Module %s already installed in %s" % (module_tag,
module_dir)
return True, tips, self.modules_dict[module_name]
if module_dir:
if md5_value:
with open(
os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp:
fp.write(md5_value)
save_path = os.path.join(MODULE_HOME, module_name)
if os.path.exists(save_path):
shutil.move(save_path)
if from_user_dir:
shutil.copytree(module_dir, save_path)
else:
shutil.move(module_dir, save_path)
module_dir = save_path
tips = "Successfully installed %s" % module_name
if installed_module_version:
tips += "-%s" % installed_module_version
return True, tips, (module_dir, installed_module_version)
tips = "Download %s-%s failed" % (module_name, module_version)
return False, tips, module_dir
shutil.move(module_dir, save_path)
module_dir = save_path
tips = "Successfully installed %s" % module_name
if installed_module_version:
tips += "-%s" % installed_module_version
return True, tips, (module_dir, installed_module_version)
tips = "Download %s-%s failed" % (module_name, module_version)
return False, tips, module_dir
def uninstall_module(self, module_name, module_version=None):
self.all_modules(update=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册