提交 537d3c58 编写于 作者: W wuzewu

Add runable() decorator

上级 e1a8d5c6
......@@ -241,7 +241,7 @@ class RunCommand(BaseCommand):
return False
if self.module.code_version == "v2":
results = self.module(argv[1:])
results = self.module.run_func(argv[1:])
else:
self.module.check_processor()
self.add_module_config_arg()
......
......@@ -113,6 +113,39 @@ class LocalModuleManager(object):
name != module_name):
if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
return False, tips, None
module_versions_info = default_hub_server.search_module_info(
module_name)
if module_versions_info is not None and len(
module_versions_info) > 0:
if utils.is_windows():
placeholders = [20, 8, 14, 14]
else:
placeholders = [30, 8, 16, 16]
tp = TablePrinter(
titles=[
"ResourceName", "Version", "PaddlePaddle",
"PaddleHub"
],
placeholders=placeholders)
module_versions_info.sort(
key=cmp_to_key(utils.sort_version_key))
for resource_name, resource_version, paddle_version, \
hub_version in module_versions_info:
colors = ["yellow", None, None, None]
tp.add_line(
contents=[
resource_name, resource_version,
utils.strflist_version(paddle_version),
utils.strflist_version(hub_version)
],
colors=colors)
tips = "The version of PaddlePaddle or PaddleHub " \
"can not match module, please upgrade your " \
"PaddlePaddle or PaddleHub according to the form " \
"below." + tp.get_text()
else:
tips = "Can't find module %s" % module_name
if module_version:
......@@ -158,73 +191,7 @@ class LocalModuleManager(object):
module_dir)
return True, tips, self.modules_dict[module_name]
search_result = hub.default_hub_server.get_module_url(
module_name, version=module_version, extra=extra)
name = search_result.get('name', None)
url = search_result.get('url', None)
md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None)
if not url or (module_version is not None and installed_module_version
!= module_version) or (name != module_name):
if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
return False, tips, None
module_versions_info = default_hub_server.search_module_info(
module_name)
if module_versions_info is not None and len(
module_versions_info) > 0:
if utils.is_windows():
placeholders = [20, 8, 14, 14]
else:
placeholders = [30, 8, 16, 16]
tp = TablePrinter(
titles=[
"ResourceName", "Version", "PaddlePaddle", "PaddleHub"
],
placeholders=placeholders)
module_versions_info.sort(
key=cmp_to_key(utils.sort_version_key))
for resource_name, resource_version, paddle_version, \
hub_version in module_versions_info:
colors = ["yellow", None, None, None]
tp.add_line(
contents=[
resource_name, resource_version,
utils.strflist_version(paddle_version),
utils.strflist_version(hub_version)
],
colors=colors)
tips = "The version of PaddlePaddle or PaddleHub " \
"can not match module, please upgrade your " \
"PaddlePaddle or PaddleHub according to the form " \
"below." + tp.get_text()
else:
tips = "Can't find module %s" % module_name
if module_version:
tips += " with version %s" % module_version
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
return False, tips, None
result, tips, module_zip_file = default_downloader.download_file(
url=url,
save_path=hub.CACHE_HOME,
save_name=module_name,
replace=True,
print_progress=True)
result, tips, module_dir = default_downloader.uncompress(
file=module_zip_file,
dirname=MODULE_HOME,
delete_file=True,
print_progress=True)
if module_dir:
with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp:
fp.write(md5_value)
if md5_value:
with open(
os.path.join(MODULE_HOME, module_dir, "md5.txt"),
......
......@@ -24,7 +24,7 @@ import functools
import inspect
import importlib
import tarfile
from collections import defaultdict
import six
from shutil import copyfile
import paddle
......@@ -103,6 +103,23 @@ def create_module(directory, name, author, email, module_type, summary,
os.remove(module_init_2)
_module_runable_func = {}
def runable(func):
if six.PY3:
mod = func.__qualname__.split(".")[:-1]
mod = ".".join(mod)
else:
mod = func.im_class.__name__
_module_runable_func[mod] = func.__name__
def _wrapper(*args, **kwargs):
return func(*args, **kwargs)
return _wrapper
class Module(object):
def __new__(cls, name=None, directory=None, module_dir=None, version=None):
module = None
......@@ -134,6 +151,12 @@ class Module(object):
version=None):
if not directory:
return
if self.__class__.__name__ in _module_runable_func:
_run_func_name = _module_runable_func[self.__class__.__name__]
self._run_func = getattr(self, _run_func_name)
else:
self._run_func = None
self._code_version = "v2"
self._directory = directory
self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
......@@ -192,6 +215,10 @@ class Module(object):
return pymodule.HubModule(directory=directory)
return ModuleV1(directory=directory)
@property
def run_func(self):
return self._run_func
@property
def desc(self):
return self._desc
......@@ -224,17 +251,13 @@ class Module(object):
def name(self):
return self._name
@property
def name_prefix(self):
return self._name_prefix
@property
def code_version(self):
return self._code_version
@property
def is_runable(self):
return False
return self._run_func != None
def _initialize(self):
pass
......@@ -631,7 +654,7 @@ class ModuleV1(Module):
return feed_dict, fetch_dict, program
def get_name_prefix(self):
return self.name_prefix
return self._name_prefix
def get_var_name_with_prefix(self, var_name):
return self.get_name_prefix() + var_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册