提交 ed8afae2 编写于 作者: W wuzewu

cache module in temp dir

上级 48d0b27d
...@@ -23,13 +23,14 @@ import shutil ...@@ -23,13 +23,14 @@ import shutil
from functools import cmp_to_key from functools import cmp_to_key
import tarfile import tarfile
import paddlehub as hub
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import MODULE_HOME from paddlehub.common.dir import MODULE_HOME
from paddlehub.common.cml_utils import TablePrinter from paddlehub.common.cml_utils import TablePrinter
from paddlehub.module import module_desc_pb2
import paddlehub as hub
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.common import tmp_dir
from paddlehub.module import module_desc_pb2
class LocalModuleManager(object): class LocalModuleManager(object):
...@@ -87,6 +88,7 @@ class LocalModuleManager(object): ...@@ -87,6 +88,7 @@ class LocalModuleManager(object):
extra=None): extra=None):
md5_value = installed_module_version = None md5_value = installed_module_version = None
from_user_dir = True if module_dir else False from_user_dir = True if module_dir else False
with tmp_dir() as _dir:
if module_name: if module_name:
self.all_modules(update=True) self.all_modules(update=True)
module_info = self.modules_dict.get(module_name, None) module_info = self.modules_dict.get(module_name, None)
...@@ -96,8 +98,8 @@ class LocalModuleManager(object): ...@@ -96,8 +98,8 @@ class LocalModuleManager(object):
module_dir = self.modules_dict[module_name][0] module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % ( module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version) module_name, module_version)
tips = "Module %s already installed in %s" % (module_tag, tips = "Module %s already installed in %s" % (
module_dir) module_tag, module_dir)
return True, tips, self.modules_dict[module_name] return True, tips, self.modules_dict[module_name]
search_result = hub.HubServer().get_module_url( search_result = hub.HubServer().get_module_url(
...@@ -107,8 +109,8 @@ class LocalModuleManager(object): ...@@ -107,8 +109,8 @@ class LocalModuleManager(object):
md5_value = search_result.get('md5', None) md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None) installed_module_version = search_result.get('version', None)
if not url or (module_version is not None if not url or (module_version is not None
and installed_module_version != module_version) or ( and installed_module_version != module_version
name != module_name): ) or (name != module_name):
if hub.HubServer()._server_check() is False: if hub.HubServer()._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network." tips = "Request Hub-Server unsuccessfully, please check your network."
return False, tips, None return False, tips, None
...@@ -152,7 +154,7 @@ class LocalModuleManager(object): ...@@ -152,7 +154,7 @@ class LocalModuleManager(object):
result, tips, module_zip_file = default_downloader.download_file( result, tips, module_zip_file = default_downloader.download_file(
url=url, url=url,
save_path=hub.CACHE_HOME, save_path=_dir,
save_name=module_name, save_name=module_name,
replace=True, replace=True,
print_progress=True) print_progress=True)
...@@ -166,13 +168,10 @@ class LocalModuleManager(object): ...@@ -166,13 +168,10 @@ class LocalModuleManager(object):
with tarfile.open(module_package, "r:gz") as tar: with tarfile.open(module_package, "r:gz") as tar:
file_names = tar.getnames() file_names = tar.getnames()
size = len(file_names) - 1 size = len(file_names) - 1
module_dir = os.path.split(file_names[0])[1] module_dir = os.path.split(file_names[0])[0]
module_dir = os.path.join(hub.CACHE_HOME, module_dir) module_dir = os.path.join(_dir, module_dir)
# remove cache
if os.path.exists(module_dir):
shutil.rmtree(module_dir)
for index, file_name in enumerate(file_names): for index, file_name in enumerate(file_names):
tar.extract(file_name, hub.CACHE_HOME) tar.extract(file_name, _dir)
if module_dir: if module_dir:
if not module_name: if not module_name:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册