提交 b1e6b364 编写于 作者: W wuzewu

fix bug

上级 8e9b98f4
......@@ -56,36 +56,10 @@ HUB_VAR_PREFIX = "@HUB_%s@"
# PaddleHub Module package suffix
HUB_PACKAGE_SUFFIX = "phm"
_module_signatures_dict = defaultdict(list)
def signature(func):
mod = func.__qualname__.split(".")[:-1]
mod = ".".join(mod)
mod_signs = _module_signatures_dict[mod]
mod_signs.append(func.__name__)
def _wrapper(*args, **kwargs):
return func(*args, **kwargs)
return _wrapper
def default_signature(func):
mod = func.__qualname__.split(".")[:-1]
mod = ".".join(mod)
mod_signs = _module_signatures_dict[mod]
mod_signs.insert(0, func.__name__)
def _wrapper(*args, **kwargs):
return func(*args, **kwargs)
return _wrapper
def create_module(directory, name, author, email, module_type, summary,
version):
output_file = "{}.{}".format(name, HUB_PACKAGE_SUFFIX)
save_file_name = "{}.{}".format(name, HUB_PACKAGE_SUFFIX)
# record module info and serialize
desc = module_desc_pb2.ModuleDesc()
......@@ -118,7 +92,7 @@ def create_module(directory, name, author, email, module_type, summary,
file.write("")
# package the module
with tarfile.open(output_file, "w:gz") as tar:
with tarfile.open(save_file_name, "w:gz") as tar:
for dirname, _, files in os.walk(directory):
for file in files:
tar.add(os.path.join(dirname, file))
......@@ -141,9 +115,16 @@ class Module(object):
elif directory:
module = cls.init_with_directory(directory=directory)
elif module_dir:
# todo
logger.warning("")
module = cls.init_with_directory(directory=module_dir[0])
logger.warning(
"Parameter module_dir is deprecated, please use directory to specify the path"
)
if isinstance(module_dir, list) or isinstance(
module_dir, tuple):
directory = module_dir[0]
version = module_dir[1]
else:
directory = module_dir
module = cls.init_with_directory(directory=directory)
if not module:
module = object.__new__(cls)
......@@ -156,22 +137,12 @@ class Module(object):
version=None):
if not directory:
return
# todo, add comment
self._directory = directory
# todo, add comment
self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
self._desc = module_desc_pb2.ModuleDesc()
self._deserialize_desc()
self._signatures = _module_signatures_dict[self.__class__.__name__]
self._default_signature = self._signatures[
0] if self._signatures else None
# todo
self._initlitizer()
with open(self.module_desc_path, "rb") as file:
self._desc.ParseFromString(file.read())
# todo
module_info = self.desc.attr.map.data['module_info']
self._name = utils.from_module_attr_to_pyobj(
module_info.map.data['name'])
......@@ -208,7 +179,6 @@ class Module(object):
checker = ModuleChecker(directory)
checker.check()
# todo
module_code_version = checker.module_code_version
if module_code_version == "v2":
basename = os.path.split(directory)[-1]
......@@ -219,16 +189,6 @@ class Module(object):
return pymodule.HubModule(directory=directory)
return ModuleV1(directory=directory)
def _deserialize_desc(self):
with open(self.module_desc_path, "rb") as file:
self._desc.ParseFromString(file.read())
def _serialize_desc(self):
pass
def check_processor(self):
pass
@property
def desc(self):
return self._desc
......@@ -261,30 +221,10 @@ class Module(object):
def name(self):
return self._name
@property
def signatures(self):
return self._signatures
@property
def default_signature(self):
return self._default_signature
@property
def name_prefix(self):
return self._name_prefix
def configs(self):
return []
def data_format(self, signature):
raise NotImplementedError
def __call__(self, signature, data, use_gpu=False, batch_size=1, **kwargs):
raise NotImplementedError
def context(self, inputs=None, program=None, trainable=False, **kwargs):
raise NotImplementedError
class ModuleHelper(object):
def __init__(self, directory):
......@@ -309,15 +249,14 @@ class ModuleHelper(object):
class ModuleV1(Module):
def __init__(self, name=None, directory=None, module_dir=None,
version=None):
# todo, add comment
if not directory:
return
super(ModuleV1, self).__init__(name, directory, module_dir, version)
self.program = None
self.assets = []
self.helper = None
self._signatures = {}
self._default_signature = None
self.signatures = {}
self.default_signature = None
self.processor = None
self.extra_info = {}
......@@ -473,7 +412,7 @@ class ModuleV1(Module):
# recover default signature
default_signature_name = utils.from_module_attr_to_pyobj(
self.desc.attr.map.data['default_signature'])
self._default_signature = self.signatures[
self.default_signature = self.signatures[
default_signature_name].name if default_signature_name else None
# recover module info
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册