提交 3b5c1a81 编写于 作者: W wuzewu

get program from variable when create module

上级 b42065bc
......@@ -276,22 +276,37 @@ class ModuleConfig(object):
return os.path.join(meta_path, GENERATOR_FILENAME)
def create_module(sign_arr, program, module_dir=None, word_dict=None):
def create_module(sign_arr, module_dir=None, word_dict=None):
""" Create a module from main program
"""
assert isinstance(
program, fluid.Program), "program should be instance of fluid.Program"
assert sign_arr, "signature array should not be None"
# check all variable
sign_arr = to_list(sign_arr)
program = sign_arr[0].get_inputs()[0].block.program
for sign in sign_arr:
assert isinstance(sign,
Signature), "sign_arr should be list of Signature"
for input in sign.get_inputs():
_tmp_program = input.block.program
assert program == _tmp_program, "all the variable should come from the same program"
for output in sign.get_outputs():
_tmp_program = output.block.program
assert program == _tmp_program, "all the variable should come from the same program"
# create module path for saving
if module_dir is None:
module_dir = os.path.join(".", "hub_module")
# create module path for saving
mkdir(module_dir)
# create module pb
module_desc = module_desc_pb2.ModuleDesc()
module_desc.version = __version__
program = program.clone()
# save asset
if word_dict is None:
module_desc.contain_assets = False
else:
......@@ -332,11 +347,8 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None):
# save signarture info
sign_map = module_desc.sign2var
sign_arr = to_list(sign_arr)
program = sign_arr[0].get_inputs()[0].block.program
for sign in sign_arr:
assert isinstance(sign,
Signature), "sign_arr should be list of Signature"
if sign.get_name() in sign_map:
raise "Error! sign_arr contains repeat signatrue %s" % sign
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册