From 5d27fa77cf98c26be638a5ce9090749ee75473ac Mon Sep 17 00:00:00 2001 From: wuzewu Date: Wed, 16 Jan 2019 14:32:48 +0800 Subject: [PATCH] add variable alias --- paddle_hub/module.py | 44 +++++++++++++++++++++++++---- paddle_hub/module_desc.proto | 2 ++ paddle_hub/module_desc_pb2.py | 52 ++++++++++++++++++++++++++++------- paddle_hub/signature.py | 47 +++++++++++++++++++++++++++++-- 4 files changed, 126 insertions(+), 19 deletions(-) diff --git a/paddle_hub/module.py b/paddle_hub/module.py index 1e8fb0cf..535a37a5 100644 --- a/paddle_hub/module.py +++ b/paddle_hub/module.py @@ -107,6 +107,26 @@ class Module(object): if op.has_attr("is_test"): op._set_attr("is_test", is_test) + def _process_input_output_key(module_desc, signature): + signature = module_desc.sign2var[signature] + + feed_dict = {} + fetch_dict = {} + + for index, feed in enumerate(signature.feed_desc): + if feed.alias != "": + feed_dict[feed.alias] = feed.var_name + feed_dict[index] = feed.var_name + + for index, fetch in enumerate(signature.fetch_desc): + if fetch.alias != "": + fetch_dict[fetch.alias] = fetch.var_name + fetch_dict[index] = fetch.var_name + + return feed_dict, fetch_dict + + self.config = ModuleConfig(self.module_dir) + self.config.load() # load paddle inference model place = fluid.CPUPlace() model_dir = os.path.join(self.module_dir, MODEL_DIRNAME) @@ -114,15 +134,15 @@ class Module(object): self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model( dirname=os.path.join(model_dir, sign_name), executor=self.exe) + feed_dict, fetch_dict = _process_input_output_key( + self.config.desc, sign_name) + # remove feed fetch operator and variable ModuleUtils.remove_feed_fetch_op(self.inference_program) # print("inference_program") # print(self.inference_program) print("**feed_target_names**\n{}".format(self.feed_target_names)) print("**fetch_targets**\n{}".format(self.fetch_targets)) - - self.config = ModuleConfig(self.module_dir) - self.config.load() self._process_parameter() name_generator_path = ModuleConfig.name_generator_path(self.module_dir) with open(name_generator_path, "rb") as data: @@ -133,7 +153,15 @@ class Module(object): _process_op_attr(program=program, is_test=False) _set_param_trainable(program=program, trainable=trainable) - return self.feed_target_names, self.fetch_targets, program, generator + for key, value in feed_dict.items(): + var = program.global_block().var(value) + feed_dict[key] = var + + for key, value in fetch_dict.items(): + var = program.global_block().var(value) + fetch_dict[key] = var + + return feed_dict, fetch_dict, program, generator def get_inference_program(self): return self.inference_program @@ -315,13 +343,17 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None): var = sign_map[sign.get_name()] feed_desc = var.feed_desc fetch_desc = var.fetch_desc - for input in sign.get_inputs(): + feed_names = sign.get_feed_names() + fetch_names = sign.get_fetch_names() + for index, input in enumerate(sign.get_inputs()): feed_var = feed_desc.add() feed_var.var_name = input.name + feed_var.alias = feed_names[index] - for output in sign.get_outputs(): + for index, output in enumerate(sign.get_outputs()): fetch_var = fetch_desc.add() fetch_var.var_name = output.name + fetch_var.alias = fetch_names[index] # save inference program exe = fluid.Executor(place=fluid.CPUPlace()) diff --git a/paddle_hub/module_desc.proto b/paddle_hub/module_desc.proto index d2daab11..96245d5b 100644 --- a/paddle_hub/module_desc.proto +++ b/paddle_hub/module_desc.proto @@ -21,11 +21,13 @@ package paddle_hub; // Feed Variable Description message FeedDesc { string var_name = 1; + string alias = 2; }; // Fetch Variable Description message FetchDesc { string var_name = 1; + string alias = 2; }; // Module Variable diff --git a/paddle_hub/module_desc_pb2.py b/paddle_hub/module_desc_pb2.py index 33b6baa1..6ac63c62 100644 --- a/paddle_hub/module_desc_pb2.py +++ b/paddle_hub/module_desc_pb2.py @@ -17,7 +17,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='paddle_hub', syntax='proto3', serialized_pb=_b( - '\n\x11module_desc.proto\x12\npaddle_hub\"\x1c\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"\x1d\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xd9\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12\x0f\n\x07version\x18\x05 \x01(\t\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3' + '\n\x11module_desc.proto\x12\npaddle_hub\"+\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\",\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xd9\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12\x0f\n\x07version\x18\x05 \x01(\t\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3' )) _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -44,6 +44,22 @@ _FEEDDESC = _descriptor.Descriptor( is_extension=False, extension_scope=None, options=None), + _descriptor.FieldDescriptor( + name='alias', + full_name='paddle_hub.FeedDesc.alias', + index=1, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), ], extensions=[], nested_types=[], @@ -54,7 +70,7 @@ _FEEDDESC = _descriptor.Descriptor( extension_ranges=[], oneofs=[], serialized_start=33, - serialized_end=61, + serialized_end=76, ) _FETCHDESC = _descriptor.Descriptor( @@ -80,6 +96,22 @@ _FETCHDESC = _descriptor.Descriptor( is_extension=False, extension_scope=None, options=None), + _descriptor.FieldDescriptor( + name='alias', + full_name='paddle_hub.FetchDesc.alias', + index=1, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), ], extensions=[], nested_types=[], @@ -89,8 +121,8 @@ _FETCHDESC = _descriptor.Descriptor( syntax='proto3', extension_ranges=[], oneofs=[], - serialized_start=63, - serialized_end=92, + serialized_start=78, + serialized_end=122, ) _MODULEVAR = _descriptor.Descriptor( @@ -141,8 +173,8 @@ _MODULEVAR = _descriptor.Descriptor( syntax='proto3', extension_ranges=[], oneofs=[], - serialized_start=94, - serialized_end=189, + serialized_start=124, + serialized_end=219, ) _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor( @@ -194,8 +226,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor( syntax='proto3', extension_ranges=[], oneofs=[], - serialized_start=339, - serialized_end=409, + serialized_start=369, + serialized_end=439, ) _MODULEDESC = _descriptor.Descriptor( @@ -296,8 +328,8 @@ _MODULEDESC = _descriptor.Descriptor( syntax='proto3', extension_ranges=[], oneofs=[], - serialized_start=192, - serialized_end=409, + serialized_start=222, + serialized_end=439, ) _MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC diff --git a/paddle_hub/signature.py b/paddle_hub/signature.py index 5fd42eb9..e9a6d887 100644 --- a/paddle_hub/signature.py +++ b/paddle_hub/signature.py @@ -20,11 +20,24 @@ from paddle_hub.utils import to_list class Signature: - def __init__(self, name, inputs, outputs): - self.name = name + def __init__(self, name, inputs, outputs, feed_names=None, + fetch_names=None): inputs = to_list(inputs) outputs = to_list(outputs) + if not feed_names: + feed_names = [""] * len(inputs) + feed_names = to_list(feed_names) + assert len(inputs) == len( + feed_names), "the length of feed_names must be same with inputs" + + if not fetch_names: + fetch_names = [""] * len(outputs) + fetch_names = to_list(fetch_names) + assert len(outputs) == len( + fetch_names), "the length of fetch_names must be same with outputs" + + self.name = name for item in inputs: assert isinstance( item, @@ -37,6 +50,29 @@ class Signature: self.inputs = inputs self.outputs = outputs + self.feed_names = feed_names + self.fetch_names = fetch_names + + +# self.inputs_dict = {} +# for index, value in enumerate(inputs): +# self.inputs_dict[index] = value + +# if feed_names: +# for index in range(len(feed_names)): +# key = feed_names[index] +# value = inputs[index] +# self.inputs_dict[key] = value + +# self.outputs_dict = {} +# for index, value in enumerate(outputs): +# self.outputs_dict[index] = value + +# if feed_names: +# for index in range(len(fetch_names)): +# key = fetch_names[index] +# value = outputs[index] +# self.outputs_dict[key] = value def get_name(self): return self.name @@ -47,7 +83,12 @@ class Signature: def get_outputs(self): return self.outputs + def get_feed_names(self): + return self.feed_names + + def get_fetch_names(self): + return self.fetch_names -def create_signature(name="default", inputs=[], outputs=[]): +def create_signature(name="default", inputs=[], outputs=[]): return Signature(name=name, inputs=inputs, outputs=outputs) -- GitLab