diff --git a/paddle_hub/module.py b/paddle_hub/module.py index 69ad0d866a1d8ba8b0b6422d731bb6373d9dee3a..77808b786c19437a43f51532427d0ae15ae9ffa8 100644 --- a/paddle_hub/module.py +++ b/paddle_hub/module.py @@ -276,22 +276,37 @@ class ModuleConfig(object): return os.path.join(meta_path, GENERATOR_FILENAME) -def create_module(sign_arr, program, module_dir=None, word_dict=None): +def create_module(sign_arr, module_dir=None, word_dict=None): """ Create a module from main program """ - assert isinstance( - program, fluid.Program), "program should be instance of fluid.Program" assert sign_arr, "signature array should not be None" + # check all variable + sign_arr = to_list(sign_arr) + program = sign_arr[0].get_inputs()[0].block.program + for sign in sign_arr: + assert isinstance(sign, + Signature), "sign_arr should be list of Signature" + + for input in sign.get_inputs(): + _tmp_program = input.block.program + assert program == _tmp_program, "all the variable should come from the same program" + + for output in sign.get_outputs(): + _tmp_program = output.block.program + assert program == _tmp_program, "all the variable should come from the same program" + + # create module path for saving if module_dir is None: module_dir = os.path.join(".", "hub_module") - # create module path for saving mkdir(module_dir) + # create module pb module_desc = module_desc_pb2.ModuleDesc() module_desc.version = __version__ program = program.clone() + # save asset if word_dict is None: module_desc.contain_assets = False else: @@ -332,11 +347,8 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None): # save signarture info sign_map = module_desc.sign2var - sign_arr = to_list(sign_arr) + program = sign_arr[0].get_inputs()[0].block.program for sign in sign_arr: - assert isinstance(sign, - Signature), "sign_arr should be list of Signature" - if sign.get_name() in sign_map: raise "Error! sign_arr contains repeat signatrue %s" % sign