提交 537d3c58 编写于 作者: W wuzewu

Add runable() decorator

上级 e1a8d5c6
...@@ -241,7 +241,7 @@ class RunCommand(BaseCommand): ...@@ -241,7 +241,7 @@ class RunCommand(BaseCommand):
return False return False
if self.module.code_version == "v2": if self.module.code_version == "v2":
results = self.module(argv[1:]) results = self.module.run_func(argv[1:])
else: else:
self.module.check_processor() self.module.check_processor()
self.add_module_config_arg() self.add_module_config_arg()
......
...@@ -113,6 +113,39 @@ class LocalModuleManager(object): ...@@ -113,6 +113,39 @@ class LocalModuleManager(object):
name != module_name): name != module_name):
if default_hub_server._server_check() is False: if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network." tips = "Request Hub-Server unsuccessfully, please check your network."
return False, tips, None
module_versions_info = default_hub_server.search_module_info(
module_name)
if module_versions_info is not None and len(
module_versions_info) > 0:
if utils.is_windows():
placeholders = [20, 8, 14, 14]
else:
placeholders = [30, 8, 16, 16]
tp = TablePrinter(
titles=[
"ResourceName", "Version", "PaddlePaddle",
"PaddleHub"
],
placeholders=placeholders)
module_versions_info.sort(
key=cmp_to_key(utils.sort_version_key))
for resource_name, resource_version, paddle_version, \
hub_version in module_versions_info:
colors = ["yellow", None, None, None]
tp.add_line(
contents=[
resource_name, resource_version,
utils.strflist_version(paddle_version),
utils.strflist_version(hub_version)
],
colors=colors)
tips = "The version of PaddlePaddle or PaddleHub " \
"can not match module, please upgrade your " \
"PaddlePaddle or PaddleHub according to the form " \
"below." + tp.get_text()
else: else:
tips = "Can't find module %s" % module_name tips = "Can't find module %s" % module_name
if module_version: if module_version:
...@@ -158,73 +191,7 @@ class LocalModuleManager(object): ...@@ -158,73 +191,7 @@ class LocalModuleManager(object):
module_dir) module_dir)
return True, tips, self.modules_dict[module_name] return True, tips, self.modules_dict[module_name]
search_result = hub.default_hub_server.get_module_url(
module_name, version=module_version, extra=extra)
name = search_result.get('name', None)
url = search_result.get('url', None)
md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None)
if not url or (module_version is not None and installed_module_version
!= module_version) or (name != module_name):
if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
return False, tips, None
module_versions_info = default_hub_server.search_module_info(
module_name)
if module_versions_info is not None and len(
module_versions_info) > 0:
if utils.is_windows():
placeholders = [20, 8, 14, 14]
else:
placeholders = [30, 8, 16, 16]
tp = TablePrinter(
titles=[
"ResourceName", "Version", "PaddlePaddle", "PaddleHub"
],
placeholders=placeholders)
module_versions_info.sort(
key=cmp_to_key(utils.sort_version_key))
for resource_name, resource_version, paddle_version, \
hub_version in module_versions_info:
colors = ["yellow", None, None, None]
tp.add_line(
contents=[
resource_name, resource_version,
utils.strflist_version(paddle_version),
utils.strflist_version(hub_version)
],
colors=colors)
tips = "The version of PaddlePaddle or PaddleHub " \
"can not match module, please upgrade your " \
"PaddlePaddle or PaddleHub according to the form " \
"below." + tp.get_text()
else:
tips = "Can't find module %s" % module_name
if module_version:
tips += " with version %s" % module_version
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
return False, tips, None
result, tips, module_zip_file = default_downloader.download_file(
url=url,
save_path=hub.CACHE_HOME,
save_name=module_name,
replace=True,
print_progress=True)
result, tips, module_dir = default_downloader.uncompress(
file=module_zip_file,
dirname=MODULE_HOME,
delete_file=True,
print_progress=True)
if module_dir: if module_dir:
with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp:
fp.write(md5_value)
if md5_value: if md5_value:
with open( with open(
os.path.join(MODULE_HOME, module_dir, "md5.txt"), os.path.join(MODULE_HOME, module_dir, "md5.txt"),
......
...@@ -24,7 +24,7 @@ import functools ...@@ -24,7 +24,7 @@ import functools
import inspect import inspect
import importlib import importlib
import tarfile import tarfile
from collections import defaultdict import six
from shutil import copyfile from shutil import copyfile
import paddle import paddle
...@@ -103,6 +103,23 @@ def create_module(directory, name, author, email, module_type, summary, ...@@ -103,6 +103,23 @@ def create_module(directory, name, author, email, module_type, summary,
os.remove(module_init_2) os.remove(module_init_2)
_module_runable_func = {}
def runable(func):
if six.PY3:
mod = func.__qualname__.split(".")[:-1]
mod = ".".join(mod)
else:
mod = func.im_class.__name__
_module_runable_func[mod] = func.__name__
def _wrapper(*args, **kwargs):
return func(*args, **kwargs)
return _wrapper
class Module(object): class Module(object):
def __new__(cls, name=None, directory=None, module_dir=None, version=None): def __new__(cls, name=None, directory=None, module_dir=None, version=None):
module = None module = None
...@@ -134,6 +151,12 @@ class Module(object): ...@@ -134,6 +151,12 @@ class Module(object):
version=None): version=None):
if not directory: if not directory:
return return
if self.__class__.__name__ in _module_runable_func:
_run_func_name = _module_runable_func[self.__class__.__name__]
self._run_func = getattr(self, _run_func_name)
else:
self._run_func = None
self._code_version = "v2" self._code_version = "v2"
self._directory = directory self._directory = directory
self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME) self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
...@@ -192,6 +215,10 @@ class Module(object): ...@@ -192,6 +215,10 @@ class Module(object):
return pymodule.HubModule(directory=directory) return pymodule.HubModule(directory=directory)
return ModuleV1(directory=directory) return ModuleV1(directory=directory)
@property
def run_func(self):
return self._run_func
@property @property
def desc(self): def desc(self):
return self._desc return self._desc
...@@ -224,17 +251,13 @@ class Module(object): ...@@ -224,17 +251,13 @@ class Module(object):
def name(self): def name(self):
return self._name return self._name
@property
def name_prefix(self):
return self._name_prefix
@property @property
def code_version(self): def code_version(self):
return self._code_version return self._code_version
@property @property
def is_runable(self): def is_runable(self):
return False return self._run_func != None
def _initialize(self): def _initialize(self):
pass pass
...@@ -631,7 +654,7 @@ class ModuleV1(Module): ...@@ -631,7 +654,7 @@ class ModuleV1(Module):
return feed_dict, fetch_dict, program return feed_dict, fetch_dict, program
def get_name_prefix(self): def get_name_prefix(self):
return self.name_prefix return self._name_prefix
def get_var_name_with_prefix(self, var_name): def get_var_name_with_prefix(self, var_name):
return self.get_name_prefix() + var_name return self.get_name_prefix() + var_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册