提交 ed8afae2 编写于 作者: W wuzewu

cache module in temp dir

上级 48d0b27d
......@@ -23,13 +23,14 @@ import shutil
from functools import cmp_to_key
import tarfile
import paddlehub as hub
from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import MODULE_HOME
from paddlehub.common.cml_utils import TablePrinter
from paddlehub.module import module_desc_pb2
import paddlehub as hub
from paddlehub.common.logger import logger
from paddlehub.common import tmp_dir
from paddlehub.module import module_desc_pb2
class LocalModuleManager(object):
......@@ -87,6 +88,7 @@ class LocalModuleManager(object):
extra=None):
md5_value = installed_module_version = None
from_user_dir = True if module_dir else False
with tmp_dir() as _dir:
if module_name:
self.all_modules(update=True)
module_info = self.modules_dict.get(module_name, None)
......@@ -96,8 +98,8 @@ class LocalModuleManager(object):
module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
tips = "Module %s already installed in %s" % (module_tag,
module_dir)
tips = "Module %s already installed in %s" % (
module_tag, module_dir)
return True, tips, self.modules_dict[module_name]
search_result = hub.HubServer().get_module_url(
......@@ -107,8 +109,8 @@ class LocalModuleManager(object):
md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None)
if not url or (module_version is not None
and installed_module_version != module_version) or (
name != module_name):
and installed_module_version != module_version
) or (name != module_name):
if hub.HubServer()._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
return False, tips, None
......@@ -152,7 +154,7 @@ class LocalModuleManager(object):
result, tips, module_zip_file = default_downloader.download_file(
url=url,
save_path=hub.CACHE_HOME,
save_path=_dir,
save_name=module_name,
replace=True,
print_progress=True)
......@@ -166,13 +168,10 @@ class LocalModuleManager(object):
with tarfile.open(module_package, "r:gz") as tar:
file_names = tar.getnames()
size = len(file_names) - 1
module_dir = os.path.split(file_names[0])[1]
module_dir = os.path.join(hub.CACHE_HOME, module_dir)
# remove cache
if os.path.exists(module_dir):
shutil.rmtree(module_dir)
module_dir = os.path.split(file_names[0])[0]
module_dir = os.path.join(_dir, module_dir)
for index, file_name in enumerate(file_names):
tar.extract(file_name, hub.CACHE_HOME)
tar.extract(file_name, _dir)
if module_dir:
if not module_name:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册