提交 7db1f61a 编写于 作者: W wuzewu

update module func

上级 1cd11734
......@@ -23,6 +23,7 @@ from paddle_hub.module import module_desc_pb2
from paddle_hub.module.signature import Signature, create_signature
from paddle_hub import version
import os
import functools
import paddle
import paddle.fluid as fluid
......@@ -39,6 +40,8 @@ def create_module(sign_arr, module_dir, exe=None):
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
# paddle hub var prefix
HUB_VAR_PREFIX = "@HUB@"
......@@ -48,8 +51,8 @@ class ModuleWrapper:
self.module = module
self.name = name
def __call__(self, trainable=False):
return self.module(self.name, trainable)
def __call__(self, data, config):
return self.module(self.name, data, config)
class ModuleHelper:
......@@ -62,6 +65,15 @@ class ModuleHelper:
def model_path(self):
return os.path.join(self.module_dir, MODEL_DIRNAME)
def processor_path(self):
return os.path.join(self.module_dir, PYTHON_DIR)
def processor_name(self):
return PROCESSOR_NAME
def assets_path(self):
return os.path.join(self.module_dir, ASSETS_DIRNAME)
class Module:
def __init__(self, url=None, module_dir=None, signatures=None, name=None):
......@@ -73,6 +85,7 @@ class Module:
self.assets = []
self.helper = None
self.signatures = {}
self.default_signature = None
if url:
self._init_with_url(url=url)
elif module_dir:
......@@ -87,6 +100,13 @@ class Module:
module_dir = downloader.download_and_uncompress(module_url)
self._init_with_module_file(module_dir)
def _load_processor(self):
import sys
processor_path = self.helper.processor_path()
sys.path.append(processor_path)
processor_name = self.helper.processor_name()
self.processor = __import__(processor_name).Processor(module=self)
def _init_with_module_file(self, module_dir):
self.helper = ModuleHelper(module_dir)
with open(self.helper.module_desc_path(), "rb") as fi:
......@@ -97,6 +117,7 @@ class Module:
self.helper.model_path(), executor=exe)
self._recovery_parameter(self.program)
self._recover_variable_info(self.program)
self._load_processor()
inputs = []
outputs = []
......@@ -174,7 +195,8 @@ class Module:
def _generate_sign_attr(self):
self._check_signatures()
for sign in self.signatures:
self.__dict__[sign] = ModuleWrapper(self, sign)
self.__dict__[sign] = functools.partial(
self.__call__, sign_name=sign)
def _generate_desc(self):
# save fluid Parameter
......@@ -215,7 +237,26 @@ class Module:
fetch_var.var_name = HUB_VAR_PREFIX + output.name
fetch_var.alias = fetch_names[index]
def __call__(self, sign_name, trainable=False):
def __call__(self, sign_name, data, config=None):
feed_dict, fetch_dict, program = self.context(sign_name)
#TODO(wuzewu): more option
program = program.clone(for_test=True)
reader = self.processor.reader(sign_name=sign_name, data_dict=data)
feed_name_list = list(
set([value.name for key, value in feed_dict.items()]))
fetch_list = list(set([value for key, value in fetch_dict.items()]))
with fluid.program_guard(program):
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
feeder = fluid.DataFeeder(feed_list=feed_name_list, place=place)
for batch in reader():
data_out = exe.run(
feed=feeder.feed(batch),
fetch_list=fetch_list,
return_numpy=False)
self.processor.postprocess(sign_name, data_out, config)
def context(self, sign_name, trainable=False):
assert sign_name in self.signatures, "module did not have a signature with name %s" % sign_name
signature = self.signatures[sign_name]
......@@ -227,6 +268,7 @@ class Module:
self._recovery_parameter(program)
self._recover_variable_info(program)
#TODO(wuzewu): return feed_list and fetch_list directly
feed_dict = {}
fetch_dict = {}
for index, var in enumerate(signature.inputs):
......@@ -243,18 +285,15 @@ class Module:
return feed_dict, fetch_dict, program
def preprocess(self):
pass
def postprocess(self):
pass
def parameters(self):
pass
def parameter_attrs(self):
pass
def default_signature(self):
return self.default_signature
def _check_signatures(self):
assert self.signatures, "signature array should not be None"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册