提交 9e425b6e 编写于 作者: W wuzewu

fix bug

上级 10e69bd5
...@@ -18,6 +18,7 @@ from __future__ import division ...@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.module.manager import default_module_manager from paddlehub.module.manager import default_module_manager
...@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand): ...@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
print("ERROR: Please specify a module name.\n") print("ERROR: Please specify a module name.\n")
self.help() self.help()
return False return False
extra = {"command": "install"}
if argv[0].endswith("tar.gz") or argv[0].endswith("phm"):
result, tips, module_dir = default_module_manager.install_module(
module_package=argv[0], extra=extra)
elif os.path.exists(argv[0]) and os.path.isdir(argv[0]):
result, tips, module_dir = default_module_manager.install_module(
module_dir=argv[0], extra=extra)
else:
module_name = argv[0] module_name = argv[0]
module_version = None if "==" not in module_name else module_name.split( module_version = None if "==" not in module_name else module_name.split(
"==")[1] "==")[1]
module_name = module_name if "==" not in module_name else module_name.split( module_name = module_name if "==" not in module_name else module_name.split(
"==")[0] "==")[0]
extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module( result, tips, module_dir = default_module_manager.install_module(
module_name=module_name, module_version=module_version, extra=extra) module_name=module_name,
module_version=module_version,
extra=extra)
print(tips) print(tips)
return True return True
......
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import os import os
import shutil import shutil
import tarfile
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common import srv_utils from paddlehub.common import srv_utils
...@@ -77,10 +78,14 @@ class LocalModuleManager(object): ...@@ -77,10 +78,14 @@ class LocalModuleManager(object):
return self.modules_dict.get(module_name, None) return self.modules_dict.get(module_name, None)
def install_module(self, def install_module(self,
module_name, module_name=None,
module_dir=None,
module_package=None,
module_version=None, module_version=None,
upgrade=False, upgrade=False,
extra=None): extra=None):
md5_value = installed_module_version = None
if module_name:
self.all_modules(update=True) self.all_modules(update=True)
module_info = self.modules_dict.get(module_name, None) module_info = self.modules_dict.get(module_name, None)
if module_info: if module_info:
...@@ -99,8 +104,9 @@ class LocalModuleManager(object): ...@@ -99,8 +104,9 @@ class LocalModuleManager(object):
url = search_result.get('url', None) url = search_result.get('url', None)
md5_value = search_result.get('md5', None) md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None) installed_module_version = search_result.get('version', None)
if not url or (module_version is not None and installed_module_version if not url or (module_version is not None
!= module_version) or (name != module_name): and installed_module_version != module_version) or (
name != module_name):
if default_hub_server._server_check() is False: if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network." tips = "Request Hub-Server unsuccessfully, please check your network."
else: else:
...@@ -123,8 +129,33 @@ class LocalModuleManager(object): ...@@ -123,8 +129,33 @@ class LocalModuleManager(object):
delete_file=True, delete_file=True,
print_progress=True) print_progress=True)
if module_package:
with tarfile.open(module_package, "r:gz") as tar:
file_names = tar.getnames()
size = len(file_names) - 1
module_dir = os.path.split(file_names[0])[0]
module_dir = os.path.join(hub.CACHE_HOME, module_dir)
if os.path.exists(module_dir):
shutil.rmtree(module_dir)
for index, file_name in enumerate(file_names):
tar.extract(file_name, hub.CACHE_HOME)
if module_dir: if module_dir:
with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"), if not module_name:
module_name = hub.Module(directory=module_dir).name
self.all_modules(update=False)
module_info = self.modules_dict.get(module_name, None)
if module_info:
module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
tips = "Module %s already installed in %s" % (module_tag,
module_dir)
return True, tips, self.modules_dict[module_name]
if md5_value:
with open(
os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp: "w") as fp:
fp.write(md5_value) fp.write(md5_value)
save_path = os.path.join(MODULE_HOME, module_name) save_path = os.path.join(MODULE_HOME, module_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册