提交 5d27fa77 编写于 作者: W wuzewu

add variable alias

上级 6decfdb9
...@@ -107,6 +107,26 @@ class Module(object): ...@@ -107,6 +107,26 @@ class Module(object):
if op.has_attr("is_test"): if op.has_attr("is_test"):
op._set_attr("is_test", is_test) op._set_attr("is_test", is_test)
def _process_input_output_key(module_desc, signature):
signature = module_desc.sign2var[signature]
feed_dict = {}
fetch_dict = {}
for index, feed in enumerate(signature.feed_desc):
if feed.alias != "":
feed_dict[feed.alias] = feed.var_name
feed_dict[index] = feed.var_name
for index, fetch in enumerate(signature.fetch_desc):
if fetch.alias != "":
fetch_dict[fetch.alias] = fetch.var_name
fetch_dict[index] = fetch.var_name
return feed_dict, fetch_dict
self.config = ModuleConfig(self.module_dir)
self.config.load()
# load paddle inference model # load paddle inference model
place = fluid.CPUPlace() place = fluid.CPUPlace()
model_dir = os.path.join(self.module_dir, MODEL_DIRNAME) model_dir = os.path.join(self.module_dir, MODEL_DIRNAME)
...@@ -114,15 +134,15 @@ class Module(object): ...@@ -114,15 +134,15 @@ class Module(object):
self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model( self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model(
dirname=os.path.join(model_dir, sign_name), executor=self.exe) dirname=os.path.join(model_dir, sign_name), executor=self.exe)
feed_dict, fetch_dict = _process_input_output_key(
self.config.desc, sign_name)
# remove feed fetch operator and variable # remove feed fetch operator and variable
ModuleUtils.remove_feed_fetch_op(self.inference_program) ModuleUtils.remove_feed_fetch_op(self.inference_program)
# print("inference_program") # print("inference_program")
# print(self.inference_program) # print(self.inference_program)
print("**feed_target_names**\n{}".format(self.feed_target_names)) print("**feed_target_names**\n{}".format(self.feed_target_names))
print("**fetch_targets**\n{}".format(self.fetch_targets)) print("**fetch_targets**\n{}".format(self.fetch_targets))
self.config = ModuleConfig(self.module_dir)
self.config.load()
self._process_parameter() self._process_parameter()
name_generator_path = ModuleConfig.name_generator_path(self.module_dir) name_generator_path = ModuleConfig.name_generator_path(self.module_dir)
with open(name_generator_path, "rb") as data: with open(name_generator_path, "rb") as data:
...@@ -133,7 +153,15 @@ class Module(object): ...@@ -133,7 +153,15 @@ class Module(object):
_process_op_attr(program=program, is_test=False) _process_op_attr(program=program, is_test=False)
_set_param_trainable(program=program, trainable=trainable) _set_param_trainable(program=program, trainable=trainable)
return self.feed_target_names, self.fetch_targets, program, generator for key, value in feed_dict.items():
var = program.global_block().var(value)
feed_dict[key] = var
for key, value in fetch_dict.items():
var = program.global_block().var(value)
fetch_dict[key] = var
return feed_dict, fetch_dict, program, generator
def get_inference_program(self): def get_inference_program(self):
return self.inference_program return self.inference_program
...@@ -315,13 +343,17 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None): ...@@ -315,13 +343,17 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None):
var = sign_map[sign.get_name()] var = sign_map[sign.get_name()]
feed_desc = var.feed_desc feed_desc = var.feed_desc
fetch_desc = var.fetch_desc fetch_desc = var.fetch_desc
for input in sign.get_inputs(): feed_names = sign.get_feed_names()
fetch_names = sign.get_fetch_names()
for index, input in enumerate(sign.get_inputs()):
feed_var = feed_desc.add() feed_var = feed_desc.add()
feed_var.var_name = input.name feed_var.var_name = input.name
feed_var.alias = feed_names[index]
for output in sign.get_outputs(): for index, output in enumerate(sign.get_outputs()):
fetch_var = fetch_desc.add() fetch_var = fetch_desc.add()
fetch_var.var_name = output.name fetch_var.var_name = output.name
fetch_var.alias = fetch_names[index]
# save inference program # save inference program
exe = fluid.Executor(place=fluid.CPUPlace()) exe = fluid.Executor(place=fluid.CPUPlace())
......
...@@ -21,11 +21,13 @@ package paddle_hub; ...@@ -21,11 +21,13 @@ package paddle_hub;
// Feed Variable Description // Feed Variable Description
message FeedDesc { message FeedDesc {
string var_name = 1; string var_name = 1;
string alias = 2;
}; };
// Fetch Variable Description // Fetch Variable Description
message FetchDesc { message FetchDesc {
string var_name = 1; string var_name = 1;
string alias = 2;
}; };
// Module Variable // Module Variable
......
...@@ -17,7 +17,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -17,7 +17,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle_hub', package='paddle_hub',
syntax='proto3', syntax='proto3',
serialized_pb=_b( serialized_pb=_b(
'\n\x11module_desc.proto\x12\npaddle_hub\"\x1c\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"\x1d\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xd9\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12\x0f\n\x07version\x18\x05 \x01(\t\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3' '\n\x11module_desc.proto\x12\npaddle_hub\"+\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\",\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xd9\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12\x0f\n\x07version\x18\x05 \x01(\t\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
)) ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -44,6 +44,22 @@ _FEEDDESC = _descriptor.Descriptor( ...@@ -44,6 +44,22 @@ _FEEDDESC = _descriptor.Descriptor(
is_extension=False, is_extension=False,
extension_scope=None, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor(
name='alias',
full_name='paddle_hub.FeedDesc.alias',
index=1,
number=2,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
...@@ -54,7 +70,7 @@ _FEEDDESC = _descriptor.Descriptor( ...@@ -54,7 +70,7 @@ _FEEDDESC = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=33, serialized_start=33,
serialized_end=61, serialized_end=76,
) )
_FETCHDESC = _descriptor.Descriptor( _FETCHDESC = _descriptor.Descriptor(
...@@ -80,6 +96,22 @@ _FETCHDESC = _descriptor.Descriptor( ...@@ -80,6 +96,22 @@ _FETCHDESC = _descriptor.Descriptor(
is_extension=False, is_extension=False,
extension_scope=None, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor(
name='alias',
full_name='paddle_hub.FetchDesc.alias',
index=1,
number=2,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
...@@ -89,8 +121,8 @@ _FETCHDESC = _descriptor.Descriptor( ...@@ -89,8 +121,8 @@ _FETCHDESC = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=63, serialized_start=78,
serialized_end=92, serialized_end=122,
) )
_MODULEVAR = _descriptor.Descriptor( _MODULEVAR = _descriptor.Descriptor(
...@@ -141,8 +173,8 @@ _MODULEVAR = _descriptor.Descriptor( ...@@ -141,8 +173,8 @@ _MODULEVAR = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=94, serialized_start=124,
serialized_end=189, serialized_end=219,
) )
_MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor( _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
...@@ -194,8 +226,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor( ...@@ -194,8 +226,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=339, serialized_start=369,
serialized_end=409, serialized_end=439,
) )
_MODULEDESC = _descriptor.Descriptor( _MODULEDESC = _descriptor.Descriptor(
...@@ -296,8 +328,8 @@ _MODULEDESC = _descriptor.Descriptor( ...@@ -296,8 +328,8 @@ _MODULEDESC = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=192, serialized_start=222,
serialized_end=409, serialized_end=439,
) )
_MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC _MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC
......
...@@ -20,11 +20,24 @@ from paddle_hub.utils import to_list ...@@ -20,11 +20,24 @@ from paddle_hub.utils import to_list
class Signature: class Signature:
def __init__(self, name, inputs, outputs): def __init__(self, name, inputs, outputs, feed_names=None,
self.name = name fetch_names=None):
inputs = to_list(inputs) inputs = to_list(inputs)
outputs = to_list(outputs) outputs = to_list(outputs)
if not feed_names:
feed_names = [""] * len(inputs)
feed_names = to_list(feed_names)
assert len(inputs) == len(
feed_names), "the length of feed_names must be same with inputs"
if not fetch_names:
fetch_names = [""] * len(outputs)
fetch_names = to_list(fetch_names)
assert len(outputs) == len(
fetch_names), "the length of fetch_names must be same with outputs"
self.name = name
for item in inputs: for item in inputs:
assert isinstance( assert isinstance(
item, item,
...@@ -37,6 +50,29 @@ class Signature: ...@@ -37,6 +50,29 @@ class Signature:
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.feed_names = feed_names
self.fetch_names = fetch_names
# self.inputs_dict = {}
# for index, value in enumerate(inputs):
# self.inputs_dict[index] = value
# if feed_names:
# for index in range(len(feed_names)):
# key = feed_names[index]
# value = inputs[index]
# self.inputs_dict[key] = value
# self.outputs_dict = {}
# for index, value in enumerate(outputs):
# self.outputs_dict[index] = value
# if feed_names:
# for index in range(len(fetch_names)):
# key = fetch_names[index]
# value = outputs[index]
# self.outputs_dict[key] = value
def get_name(self): def get_name(self):
return self.name return self.name
...@@ -47,7 +83,12 @@ class Signature: ...@@ -47,7 +83,12 @@ class Signature:
def get_outputs(self): def get_outputs(self):
return self.outputs return self.outputs
def get_feed_names(self):
return self.feed_names
def get_fetch_names(self):
return self.fetch_names
def create_signature(name="default", inputs=[], outputs=[]):
def create_signature(name="default", inputs=[], outputs=[]):
return Signature(name=name, inputs=inputs, outputs=outputs) return Signature(name=name, inputs=inputs, outputs=outputs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册