提交 ff0a843f 编写于 作者: Z Zeyu Chen

update proto to supprot mulitiple signature

上级 5a783950
...@@ -197,6 +197,12 @@ class ModuleConfig(object): ...@@ -197,6 +197,12 @@ class ModuleConfig(object):
self.dict = defaultdict(int) self.dict = defaultdict(int)
self.dict.setdefault(0) self.dict.setdefault(0)
# feed_list
self.feed_list = []
# fetch_list
self.fetch_list = []
def load(self): def load(self):
"""load module config from module dir """load module config from module dir
""" """
...@@ -219,7 +225,9 @@ class ModuleConfig(object): ...@@ -219,7 +225,9 @@ class ModuleConfig(object):
self.dict[w] = int(w_id) self.dict[w] = int(w_id)
def dump(self): def dump(self):
# save module_desc.proto first """
save module_desc.proto first
"""
pb_path = os.path.join(self.module_dir, "module_desc.pb") pb_path = os.path.join(self.module_dir, "module_desc.pb")
with open(pb_path, "wb") as fo: with open(pb_path, "wb") as fo:
fo.write(self.desc.SerializeToString()) fo.write(self.desc.SerializeToString())
...@@ -232,6 +240,10 @@ class ModuleConfig(object): ...@@ -232,6 +240,10 @@ class ModuleConfig(object):
w_id = self.dict[w] w_id = self.dict[w]
fo.write("{}\t{}\n".format(w, w_id)) fo.write("{}\t{}\n".format(w, w_id))
def register_input_var(self, var):
var_name = var.name()
self.feed_list.add(var_name)
def save_dict(self, word_dict, dict_name=DICT_NAME): def save_dict(self, word_dict, dict_name=DICT_NAME):
""" Save dictionary for NLP module """ Save dictionary for NLP module
""" """
......
...@@ -19,11 +19,11 @@ option optimize_for = LITE_RUNTIME; ...@@ -19,11 +19,11 @@ option optimize_for = LITE_RUNTIME;
package paddle_hub; package paddle_hub;
message InputDesc { message InputDesc {
string name = 1; repeated string name = 1;
}; };
message OutputDesc { message OutputDesc {
string name = 1; repeated string name = 1;
}; };
// A Hub Module is stored in a directory with a file 'paddlehub.pb' // A Hub Module is stored in a directory with a file 'paddlehub.pb'
...@@ -32,14 +32,14 @@ message OutputDesc { ...@@ -32,14 +32,14 @@ message OutputDesc {
message ModuleDesc { message ModuleDesc {
string name = 1; // PaddleHub module name string name = 1; // PaddleHub module name
repeated InputDesc input_desc = 2; // signature to input description
map<string, InputDesc> sign2input = 2;
repeated OutputDesc output_desc = 3; // signature to output description
map<string, OutputDesc> sign2output = 3;
string signature = 4; bool return_numpy = 4;
bool return_numpy = 5; bool contain_assets = 5;
bool contain_assets = 6;
}; };
...@@ -17,7 +17,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -17,7 +17,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle_hub', package='paddle_hub',
syntax='proto3', syntax='proto3',
serialized_pb=_b( serialized_pb=_b(
'\n\x11module_desc.proto\x12\npaddle_hub\"\x19\n\tInputDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1a\n\nOutputDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\"\xb3\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12)\n\ninput_desc\x18\x02 \x03(\x0b\x32\x15.paddle_hub.InputDesc\x12+\n\x0boutput_desc\x18\x03 \x03(\x0b\x32\x16.paddle_hub.OutputDesc\x12\x11\n\tsignature\x18\x04 \x01(\t\x12\x14\n\x0creturn_numpy\x18\x05 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x06 \x01(\x08\x42\x02H\x03\x62\x06proto3' '\n\x11module_desc.proto\x12\npaddle_hub\"\x19\n\tInputDesc\x12\x0c\n\x04name\x18\x01 \x03(\t\"\x1a\n\nOutputDesc\x12\x0c\n\x04name\x18\x01 \x03(\t\"\xd8\x02\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12:\n\nsign2input\x18\x02 \x03(\x0b\x32&.paddle_hub.ModuleDesc.Sign2inputEntry\x12<\n\x0bsign2output\x18\x03 \x03(\x0b\x32\'.paddle_hub.ModuleDesc.Sign2outputEntry\x12\x14\n\x0creturn_numpy\x18\x04 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x05 \x01(\x08\x1aH\n\x0fSign2inputEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.InputDesc:\x02\x38\x01\x1aJ\n\x10Sign2outputEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.paddle_hub.OutputDesc:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
)) ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -35,9 +35,9 @@ _INPUTDESC = _descriptor.Descriptor( ...@@ -35,9 +35,9 @@ _INPUTDESC = _descriptor.Descriptor(
number=1, number=1,
type=9, type=9,
cpp_type=9, cpp_type=9,
label=1, label=3,
has_default_value=False, has_default_value=False,
default_value=_b("").decode('utf-8'), default_value=[],
message_type=None, message_type=None,
enum_type=None, enum_type=None,
containing_type=None, containing_type=None,
...@@ -70,9 +70,9 @@ _OUTPUTDESC = _descriptor.Descriptor( ...@@ -70,9 +70,9 @@ _OUTPUTDESC = _descriptor.Descriptor(
number=1, number=1,
type=9, type=9,
cpp_type=9, cpp_type=9,
label=1, label=3,
has_default_value=False, has_default_value=False,
default_value=_b("").decode('utf-8'), default_value=[],
message_type=None, message_type=None,
enum_type=None, enum_type=None,
containing_type=None, containing_type=None,
...@@ -91,16 +91,16 @@ _OUTPUTDESC = _descriptor.Descriptor( ...@@ -91,16 +91,16 @@ _OUTPUTDESC = _descriptor.Descriptor(
serialized_start=60, serialized_start=60,
serialized_end=86, ) serialized_end=86, )
_MODULEDESC = _descriptor.Descriptor( _MODULEDESC_SIGN2INPUTENTRY = _descriptor.Descriptor(
name='ModuleDesc', name='Sign2inputEntry',
full_name='paddle_hub.ModuleDesc', full_name='paddle_hub.ModuleDesc.Sign2inputEntry',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='name', name='key',
full_name='paddle_hub.ModuleDesc.name', full_name='paddle_hub.ModuleDesc.Sign2inputEntry.key',
index=0, index=0,
number=1, number=1,
type=9, type=9,
...@@ -115,42 +115,98 @@ _MODULEDESC = _descriptor.Descriptor( ...@@ -115,42 +115,98 @@ _MODULEDESC = _descriptor.Descriptor(
extension_scope=None, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='input_desc', name='value',
full_name='paddle_hub.ModuleDesc.input_desc', full_name='paddle_hub.ModuleDesc.Sign2inputEntry.value',
index=1, index=1,
number=2, number=2,
type=11, type=11,
cpp_type=10, cpp_type=10,
label=3, label=1,
has_default_value=False, has_default_value=False,
default_value=[], default_value=None,
message_type=None, message_type=None,
enum_type=None, enum_type=None,
containing_type=None, containing_type=None,
is_extension=False, is_extension=False,
extension_scope=None, extension_scope=None,
options=None), options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(),
_b('8\001')),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=285,
serialized_end=357, )
_MODULEDESC_SIGN2OUTPUTENTRY = _descriptor.Descriptor(
name='Sign2outputEntry',
full_name='paddle_hub.ModuleDesc.Sign2outputEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='output_desc', name='key',
full_name='paddle_hub.ModuleDesc.output_desc', full_name='paddle_hub.ModuleDesc.Sign2outputEntry.key',
index=2, index=0,
number=3, number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='value',
full_name='paddle_hub.ModuleDesc.Sign2outputEntry.value',
index=1,
number=2,
type=11, type=11,
cpp_type=10, cpp_type=10,
label=3, label=1,
has_default_value=False, has_default_value=False,
default_value=[], default_value=None,
message_type=None, message_type=None,
enum_type=None, enum_type=None,
containing_type=None, containing_type=None,
is_extension=False, is_extension=False,
extension_scope=None, extension_scope=None,
options=None), options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(),
_b('8\001')),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=359,
serialized_end=433, )
_MODULEDESC = _descriptor.Descriptor(
name='ModuleDesc',
full_name='paddle_hub.ModuleDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='signature', name='name',
full_name='paddle_hub.ModuleDesc.signature', full_name='paddle_hub.ModuleDesc.name',
index=3, index=0,
number=4, number=1,
type=9, type=9,
cpp_type=9, cpp_type=9,
label=1, label=1,
...@@ -162,11 +218,43 @@ _MODULEDESC = _descriptor.Descriptor( ...@@ -162,11 +218,43 @@ _MODULEDESC = _descriptor.Descriptor(
is_extension=False, is_extension=False,
extension_scope=None, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor(
name='sign2input',
full_name='paddle_hub.ModuleDesc.sign2input',
index=1,
number=2,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='sign2output',
full_name='paddle_hub.ModuleDesc.sign2output',
index=2,
number=3,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='return_numpy', name='return_numpy',
full_name='paddle_hub.ModuleDesc.return_numpy', full_name='paddle_hub.ModuleDesc.return_numpy',
index=4, index=3,
number=5, number=4,
type=8, type=8,
cpp_type=7, cpp_type=7,
label=1, label=1,
...@@ -181,8 +269,8 @@ _MODULEDESC = _descriptor.Descriptor( ...@@ -181,8 +269,8 @@ _MODULEDESC = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='contain_assets', name='contain_assets',
full_name='paddle_hub.ModuleDesc.contain_assets', full_name='paddle_hub.ModuleDesc.contain_assets',
index=5, index=4,
number=6, number=5,
type=8, type=8,
cpp_type=7, cpp_type=7,
label=1, label=1,
...@@ -196,7 +284,10 @@ _MODULEDESC = _descriptor.Descriptor( ...@@ -196,7 +284,10 @@ _MODULEDESC = _descriptor.Descriptor(
options=None), options=None),
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[
_MODULEDESC_SIGN2INPUTENTRY,
_MODULEDESC_SIGN2OUTPUTENTRY,
],
enum_types=[], enum_types=[],
options=None, options=None,
is_extendable=False, is_extendable=False,
...@@ -204,10 +295,16 @@ _MODULEDESC = _descriptor.Descriptor( ...@@ -204,10 +295,16 @@ _MODULEDESC = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=89, serialized_start=89,
serialized_end=268, ) serialized_end=433, )
_MODULEDESC.fields_by_name['input_desc'].message_type = _INPUTDESC _MODULEDESC_SIGN2INPUTENTRY.fields_by_name['value'].message_type = _INPUTDESC
_MODULEDESC.fields_by_name['output_desc'].message_type = _OUTPUTDESC _MODULEDESC_SIGN2INPUTENTRY.containing_type = _MODULEDESC
_MODULEDESC_SIGN2OUTPUTENTRY.fields_by_name['value'].message_type = _OUTPUTDESC
_MODULEDESC_SIGN2OUTPUTENTRY.containing_type = _MODULEDESC
_MODULEDESC.fields_by_name[
'sign2input'].message_type = _MODULEDESC_SIGN2INPUTENTRY
_MODULEDESC.fields_by_name[
'sign2output'].message_type = _MODULEDESC_SIGN2OUTPUTENTRY
DESCRIPTOR.message_types_by_name['InputDesc'] = _INPUTDESC DESCRIPTOR.message_types_by_name['InputDesc'] = _INPUTDESC
DESCRIPTOR.message_types_by_name['OutputDesc'] = _OUTPUTDESC DESCRIPTOR.message_types_by_name['OutputDesc'] = _OUTPUTDESC
DESCRIPTOR.message_types_by_name['ModuleDesc'] = _MODULEDESC DESCRIPTOR.message_types_by_name['ModuleDesc'] = _MODULEDESC
...@@ -236,13 +333,37 @@ ModuleDesc = _reflection.GeneratedProtocolMessageType( ...@@ -236,13 +333,37 @@ ModuleDesc = _reflection.GeneratedProtocolMessageType(
'ModuleDesc', 'ModuleDesc',
(_message.Message, ), (_message.Message, ),
dict( dict(
Sign2inputEntry=_reflection.GeneratedProtocolMessageType(
'Sign2inputEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_MODULEDESC_SIGN2INPUTENTRY,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc.Sign2inputEntry)
)),
Sign2outputEntry=_reflection.GeneratedProtocolMessageType(
'Sign2outputEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_MODULEDESC_SIGN2OUTPUTENTRY,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc.Sign2outputEntry)
)),
DESCRIPTOR=_MODULEDESC, DESCRIPTOR=_MODULEDESC,
__module__='module_desc_pb2' __module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc) # @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc)
)) ))
_sym_db.RegisterMessage(ModuleDesc) _sym_db.RegisterMessage(ModuleDesc)
_sym_db.RegisterMessage(ModuleDesc.Sign2inputEntry)
_sym_db.RegisterMessage(ModuleDesc.Sign2outputEntry)
DESCRIPTOR.has_options = True DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(),
_b('H\003')) _b('H\003'))
_MODULEDESC_SIGN2INPUTENTRY.has_options = True
_MODULEDESC_SIGN2INPUTENTRY._options = _descriptor._ParseOptions(
descriptor_pb2.MessageOptions(), _b('8\001'))
_MODULEDESC_SIGN2OUTPUTENTRY.has_options = True
_MODULEDESC_SIGN2OUTPUTENTRY._options = _descriptor._ParseOptions(
descriptor_pb2.MessageOptions(), _b('8\001'))
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册