提交 5737d422 编写于 作者: W wuzewu

add extra info in module

上级 321b7083
......@@ -59,7 +59,14 @@ def create_module(args):
module_dir=args.model + ".hub_module",
module_info="resources/module_info.yml",
processor=processor.Processor,
assets=assets)
assets=assets,
extra_info={
'excepted_image_width': 224,
'excepted_image_height': 224,
'pretrained_images_mean': [0.485, 0.456, 0.406],
'pretrained_images_std': [0.229, 0.224, 0.225],
'image_channel_order': 'RGB'
})
def main():
......
......@@ -40,6 +40,15 @@ from paddlehub import version
__all__ = ['Module', 'create_module']
# paddle hub module dir name
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
# paddle hub var prefix
HUB_VAR_PREFIX = "@HUB_%s@"
def set_max_seq_len(program, input_dict):
""" Set """
......@@ -51,26 +60,18 @@ def create_module(sign_arr,
processor=None,
assets=None,
module_info=None,
exe=None):
exe=None,
extra_info=None):
sign_arr = utils.to_list(sign_arr)
module = Module(
signatures=sign_arr,
processor=processor,
assets=assets,
module_info=module_info)
module_info=module_info,
extra_info=extra_info)
module.serialize_to_path(path=module_dir, exe=exe)
# paddle hub module dir name
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
# paddle hub var prefix
HUB_VAR_PREFIX = "@HUB_%s@"
class ModuleHelper(object):
def __init__(self, module_dir):
self.module_dir = module_dir
......@@ -99,7 +100,8 @@ class Module(object):
signatures=None,
module_info=None,
assets=None,
processor=None):
processor=None,
extra_info=None):
self.desc = module_desc_pb2.ModuleDesc()
self.program = None
self.assets = []
......@@ -108,6 +110,10 @@ class Module(object):
self.default_signature = None
self.module_info = None
self.processor = None
self.extra_info = {} if extra_info is None else extra_info
if not isinstance(self.extra_info, dict):
raise TypeError(
"The extra_info should be an instance of python dict")
# TODO(wuzewu): print more module loading info log
if name:
self._init_with_name(name=name)
......@@ -204,6 +210,7 @@ class Module(object):
self._load_assets()
self._recover_from_desc()
self._generate_sign_attr()
self._generate_extra_info()
self._restore_parameter(self.program)
self._recover_variable_info(self.program)
......@@ -213,6 +220,7 @@ class Module(object):
self._check_signatures()
self._generate_desc()
self._generate_sign_attr()
self._generate_extra_info()
def _init_with_program(self, program):
pass
......@@ -261,6 +269,14 @@ class Module(object):
var = block.vars[var_name]
var.stop_gradient = stop_gradient
def get_extra_info(self, key):
return self.extra_info.get(key, None)
def _generate_extra_info(self):
for key in self.extra_info:
self.__dict__["get_%s" % key] = functools.partial(
self.get_extra_info, key=key)
def _generate_module_info(self, module_info=None):
if not module_info:
self.module_info = {}
......@@ -332,6 +348,12 @@ class Module(object):
self.summary = utils.from_module_attr_to_pyobj(
module_info.map.data['summary'])
# recover extra info
extra_info = self.desc.attr.map.data['extra_info']
self.extra_info = {}
for key, value in extra_info.map.data.items():
self.extra_info[key] = utils.from_module_attr_to_pyobj(value)
# recover name prefix
self.name_prefix = utils.from_module_attr_to_pyobj(
self.desc.attr.map.data["name_prefix"])
......@@ -398,6 +420,12 @@ class Module(object):
utils.from_pyobj_to_module_attr(self.summary,
module_info.map.data['summary'])
# save extra info
extra_info = attr.map.data['extra_info']
extra_info.type = module_desc_pb2.MAP
for key, value in self.extra_info.items():
utils.from_pyobj_to_module_attr(value, extra_info.map.data[key])
def __call__(self, sign_name, data, **kwargs):
self.check_processor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册