From c9ac67d6c62d3ea78addb2bdeb8a602c7731ee7e Mon Sep 17 00:00:00 2001 From: wuzewu Date: Mon, 21 Jan 2019 19:39:43 +0800 Subject: [PATCH] save param attr with pb format --- paddle_hub/module.py | 99 ++++- paddle_hub/module_desc.proto | 43 ++ paddle_hub/module_desc_pb2.py | 737 +++++++++++++++++++++++++++++++++- 3 files changed, 848 insertions(+), 31 deletions(-) diff --git a/paddle_hub/module.py b/paddle_hub/module.py index 12290b1e..95bb75e8 100644 --- a/paddle_hub/module.py +++ b/paddle_hub/module.py @@ -80,11 +80,48 @@ class Module(object): def _process_parameter(self): global_block = self.inference_program.global_block() - param_path = ModuleConfig.meta_param_path(self.module_dir) - with open(param_path, "rb") as file: - param_arr = pickle.load(file) - for param in param_arr: - param['name'] = HUB_VAR_PREFIX + param['name'] + param_attrs = self.config.desc.param_attrs + for key, param_attr in param_attrs.items(): + param = {} + param['name'] = HUB_VAR_PREFIX + key + param['trainable'] = param_attr.trainable + param['do_model_average'] = param_attr.do_model_average + param['optimize_attr'] = {} + param['optimize_attr'][ + 'learning_rate'] = param_attr.optimize_attr.m['learning_rate'].f + + # TODO(wuzewu): recover the param attr with a more reliable way + if param_attr.regularizer.type == "L2DecayRegularizer": + regularizer = fluid.regularizer.L2DecayRegularizer( + regularization_coeff=param_attr.regularizer. + regularization_coeff) + elif param_attr.regularizer.type == "L1DecayRegularizer": + regularizer = fluid.regularizer.L1DecayRegularizer( + regularization_coeff=param_attr.regularizer. + regularization_coeff) + else: + regularizer = None + param['regularizer'] = regularizer + + if param_attr.gradient_clip_attr.type == "ErrorClipByValue": + clip = fluid.clip.ErrorClipByValue( + max=param_attr.gradient_clip_attr.max, + min=param_attr.gradient_clip_attr.min) + elif param_attr.gradient_clip_attr.type == "GradientClipByValue": + clip = fluid.clip.GradientClipByValue( + max=param_attr.gradient_clip_attr.max, + min=param_attr.gradient_clip_attr.min) + elif param_attr.gradient_clip_attr.type == "GradientClipByNorm": + clip = fluid.clip.GradientClipByNorm( + clip_norm=param_attr.gradient_clip_attr.clip_norm) + elif param_attr.gradient_clip_attr.type == "GradientClipByGlobalNorm": + clip = fluid.clip.GradientClipByNorm( + clip_norm=param_attr.gradient_clip_attr.clip_norm, + group_name=param_attr.gradient_clip_attr.group_name) + else: + clip = None + param['gradient_clip_attr'] = clip + if (param['name'] not in global_block.vars): continue var = global_block.var(param['name']) @@ -309,20 +346,46 @@ def create_module(sign_arr, module_dir=None, word_dict=None): fo.write("{}\t{}\n".format(w, w_id)) # save fluid Parameter - param_arr = [] + param_attrs = module_desc.param_attrs for param in program.global_block().iter_parameters(): - param_info = { - 'name': param.name, - 'regularizer': param.regularizer, - 'gradient_clip_attr': param.gradient_clip_attr, - 'trainable': param.trainable, - 'optimize_attr': param.optimize_attr, - 'do_model_average': param.do_model_average - } - param_arr.append(param_info) - - with open(ModuleConfig.meta_param_path(module_dir), "wb") as fo: - pickle.dump(param_arr, fo) + param_attr = param_attrs[param.name] + param_attr.trainable = param.trainable + if param.do_model_average: + param_attr.do_model_average = param.do_model_average + # TODO(wuzewu): add a func to transfer python dict to fexiable data + param_attr.optimize_attr.type = module_desc_pb2.MAP + param_attr.optimize_attr.m['learning_rate'].type = module_desc_pb2.FLOAT + param_attr.optimize_attr.m['learning_rate'].f = param.optimize_attr[ + 'learning_rate'] + if param.regularizer: + if isinstance(param.regularizer, + fluid.regularizer.L2DecayRegularizer): + param_attr.regularizer.type = "L2DecayRegularizer" + if isinstance(param.regularizer, + fluid.regularizer.L1DecayRegularizer): + param_attr.regularizer.type = "L1DecayRegularizer" + param_attr.regularizer.regularization_coeff = param.regularizer.regularization_coeff + + if param.gradient_clip_attr: + if isinstance(param.gradient_clip_attr, + fluid.clip.ErrorClipByValue): + param_attr.gradient_clip_attr.max = param.gradient_clip_attr.max + param_attr.gradient_clip_attr.min = param.gradient_clip_attr.min + param_attr.gradient_clip_attr.type = "ErrorClipByValue" + if isinstance(param.gradient_clip_attr, + fluid.clip.GradientClipByValue): + param_attr.gradient_clip_attr.max = param.gradient_clip_attr.max + param_attr.gradient_clip_attr.min = param.gradient_clip_attr.min + param_attr.gradient_clip_attr.type = "GradientClipByValue" + if isinstance(param.gradient_clip_attr, + fluid.clip.GradientClipByNorm): + param_attr.gradient_clip_attr.clip_norm = param.gradient_clip_attr.clip_norm + param_attr.gradient_clip_attr.type = "GradientClipByNorm" + if isinstance(param.gradient_clip_attr, + fluid.clip.GradientClipByGlobalNorm): + param_attr.gradient_clip_attr.clip_norm = param.gradient_clip_attr.clip_norm + param_attr.gradient_clip_attr.group_name = param.gradient_clip_attr.group_name + param_attr.gradient_clip_attr.type = "GradientClipByGlobalNorm" # save signarture info sign_map = module_desc.sign2var diff --git a/paddle_hub/module_desc.proto b/paddle_hub/module_desc.proto index dd5d9a19..174d9875 100644 --- a/paddle_hub/module_desc.proto +++ b/paddle_hub/module_desc.proto @@ -18,6 +18,26 @@ option optimize_for = LITE_RUNTIME; package paddle_hub; +enum DataType { + INT = 0; + FLOAT = 1; + STRING = 2; + BOOLEAN = 3; + LIST = 4; + MAP = 5; +} + +message FlexibleData { + DataType type = 1; + string name = 2; + int32 i = 3; + float f = 4; + bool b = 5; + string s = 6; + map m = 7; + map l = 8; +} + // Feed Variable Description message FeedDesc { string var_name = 1; @@ -41,6 +61,27 @@ message AuthInfo { string hub_version = 2; } +message ParamAttr { + message Regularizer { + string type = 1; + float regularization_coeff = 2; + } + + message GradientClipAttr { + string type = 1; + float min = 2; + float max = 3; + float clip_norm = 4; + string group_name = 5; + } + + Regularizer regularizer = 1; + GradientClipAttr gradient_clip_attr = 2; + FlexibleData optimize_attr = 3; + bool trainable = 4; + bool do_model_average = 5; +} + // A Hub Module is stored in a directory with a file 'paddlehub.pb' // containing a serialized protocol message of this type. The further contents // of the directory depend on the storage format described by the message. @@ -56,5 +97,7 @@ message ModuleDesc { bool contain_assets = 4; AuthInfo auth_info = 5; + + map param_attrs = 6; }; diff --git a/paddle_hub/module_desc_pb2.py b/paddle_hub/module_desc_pb2.py index 8b5bd301..484b9306 100644 --- a/paddle_hub/module_desc_pb2.py +++ b/paddle_hub/module_desc_pb2.py @@ -3,6 +3,7 @@ import sys _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1')) +from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection @@ -17,10 +18,301 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='paddle_hub', syntax='proto3', serialized_pb=_b( - '\n\x11module_desc.proto\x12\npaddle_hub\"+\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\",\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"7\n\x08\x41uthInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\"\xf1\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12\'\n\tauth_info\x18\x05 \x01(\x0b\x32\x14.paddle_hub.AuthInfo\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3' + '\n\x11module_desc.proto\x12\npaddle_hub\"\xcc\x02\n\x0c\x46lexibleData\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.paddle_hub.DataType\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\t\n\x01i\x18\x03 \x01(\x05\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01\x62\x18\x05 \x01(\x08\x12\t\n\x01s\x18\x06 \x01(\t\x12*\n\x01m\x18\x07 \x03(\x0b\x32\x1f.paddle_hub.FlexibleData.MEntry\x12*\n\x01l\x18\x08 \x03(\x0b\x32\x1f.paddle_hub.FlexibleData.LEntry\x1a\x42\n\x06MEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.paddle_hub.FlexibleData:\x02\x38\x01\x1a\x42\n\x06LEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.paddle_hub.FlexibleData:\x02\x38\x01\"+\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\",\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"7\n\x08\x41uthInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\"\x83\x03\n\tParamAttr\x12\x36\n\x0bregularizer\x18\x01 \x01(\x0b\x32!.paddle_hub.ParamAttr.Regularizer\x12\x42\n\x12gradient_clip_attr\x18\x02 \x01(\x0b\x32&.paddle_hub.ParamAttr.GradientClipAttr\x12/\n\roptimize_attr\x18\x03 \x01(\x0b\x32\x18.paddle_hub.FlexibleData\x12\x11\n\ttrainable\x18\x04 \x01(\x08\x12\x18\n\x10\x64o_model_average\x18\x05 \x01(\x08\x1a\x39\n\x0bRegularizer\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x1c\n\x14regularization_coeff\x18\x02 \x01(\x02\x1a\x61\n\x10GradientClipAttr\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03min\x18\x02 \x01(\x02\x12\x0b\n\x03max\x18\x03 \x01(\x02\x12\x11\n\tclip_norm\x18\x04 \x01(\x02\x12\x12\n\ngroup_name\x18\x05 \x01(\t\"\xf8\x02\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12\'\n\tauth_info\x18\x05 \x01(\x0b\x32\x14.paddle_hub.AuthInfo\x12;\n\x0bparam_attrs\x18\x06 \x03(\x0b\x32&.paddle_hub.ModuleDesc.ParamAttrsEntry\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x1aH\n\x0fParamAttrsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ParamAttr:\x02\x38\x01*J\n\x08\x44\x61taType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x0b\n\x07\x42OOLEAN\x10\x03\x12\x08\n\x04LIST\x10\x04\x12\x07\n\x03MAP\x10\x05\x42\x02H\x03\x62\x06proto3' )) _sym_db.RegisterFileDescriptor(DESCRIPTOR) +_DATATYPE = _descriptor.EnumDescriptor( + name='DataType', + full_name='paddle_hub.DataType', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='INT', index=0, number=0, options=None, type=None), + _descriptor.EnumValueDescriptor( + name='FLOAT', index=1, number=1, options=None, type=None), + _descriptor.EnumValueDescriptor( + name='STRING', index=2, number=2, options=None, type=None), + _descriptor.EnumValueDescriptor( + name='BOOLEAN', index=3, number=3, options=None, type=None), + _descriptor.EnumValueDescriptor( + name='LIST', index=4, number=4, options=None, type=None), + _descriptor.EnumValueDescriptor( + name='MAP', index=5, number=5, options=None, type=None), + ], + containing_type=None, + options=None, + serialized_start=1382, + serialized_end=1456, +) +_sym_db.RegisterEnumDescriptor(_DATATYPE) + +DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) +INT = 0 +FLOAT = 1 +STRING = 2 +BOOLEAN = 3 +LIST = 4 +MAP = 5 + +_FLEXIBLEDATA_MENTRY = _descriptor.Descriptor( + name='MEntry', + full_name='paddle_hub.FlexibleData.MEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', + full_name='paddle_hub.FlexibleData.MEntry.key', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='value', + full_name='paddle_hub.FlexibleData.MEntry.value', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), + _b('8\001')), + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=232, + serialized_end=298, +) + +_FLEXIBLEDATA_LENTRY = _descriptor.Descriptor( + name='LEntry', + full_name='paddle_hub.FlexibleData.LEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', + full_name='paddle_hub.FlexibleData.LEntry.key', + index=0, + number=1, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='value', + full_name='paddle_hub.FlexibleData.LEntry.value', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), + _b('8\001')), + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=300, + serialized_end=366, +) + +_FLEXIBLEDATA = _descriptor.Descriptor( + name='FlexibleData', + full_name='paddle_hub.FlexibleData', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='type', + full_name='paddle_hub.FlexibleData.type', + index=0, + number=1, + type=14, + cpp_type=8, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='name', + full_name='paddle_hub.FlexibleData.name', + index=1, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='i', + full_name='paddle_hub.FlexibleData.i', + index=2, + number=3, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='f', + full_name='paddle_hub.FlexibleData.f', + index=3, + number=4, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='b', + full_name='paddle_hub.FlexibleData.b', + index=4, + number=5, + type=8, + cpp_type=7, + label=1, + has_default_value=False, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='s', + full_name='paddle_hub.FlexibleData.s', + index=5, + number=6, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='m', + full_name='paddle_hub.FlexibleData.m', + index=6, + number=7, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='l', + full_name='paddle_hub.FlexibleData.l', + index=7, + number=8, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[ + _FLEXIBLEDATA_MENTRY, + _FLEXIBLEDATA_LENTRY, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=34, + serialized_end=366, +) + _FEEDDESC = _descriptor.Descriptor( name='FeedDesc', full_name='paddle_hub.FeedDesc', @@ -69,8 +361,8 @@ _FEEDDESC = _descriptor.Descriptor( syntax='proto3', extension_ranges=[], oneofs=[], - serialized_start=33, - serialized_end=76, + serialized_start=368, + serialized_end=411, ) _FETCHDESC = _descriptor.Descriptor( @@ -121,8 +413,8 @@ _FETCHDESC = _descriptor.Descriptor( syntax='proto3', extension_ranges=[], oneofs=[], - serialized_start=78, - serialized_end=122, + serialized_start=413, + serialized_end=457, ) _MODULEVAR = _descriptor.Descriptor( @@ -173,8 +465,8 @@ _MODULEVAR = _descriptor.Descriptor( syntax='proto3', extension_ranges=[], oneofs=[], - serialized_start=124, - serialized_end=219, + serialized_start=459, + serialized_end=554, ) _AUTHINFO = _descriptor.Descriptor( @@ -225,8 +517,263 @@ _AUTHINFO = _descriptor.Descriptor( syntax='proto3', extension_ranges=[], oneofs=[], - serialized_start=221, - serialized_end=276, + serialized_start=556, + serialized_end=611, +) + +_PARAMATTR_REGULARIZER = _descriptor.Descriptor( + name='Regularizer', + full_name='paddle_hub.ParamAttr.Regularizer', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='type', + full_name='paddle_hub.ParamAttr.Regularizer.type', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='regularization_coeff', + full_name='paddle_hub.ParamAttr.Regularizer.regularization_coeff', + index=1, + number=2, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=845, + serialized_end=902, +) + +_PARAMATTR_GRADIENTCLIPATTR = _descriptor.Descriptor( + name='GradientClipAttr', + full_name='paddle_hub.ParamAttr.GradientClipAttr', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='type', + full_name='paddle_hub.ParamAttr.GradientClipAttr.type', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='min', + full_name='paddle_hub.ParamAttr.GradientClipAttr.min', + index=1, + number=2, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='max', + full_name='paddle_hub.ParamAttr.GradientClipAttr.max', + index=2, + number=3, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='clip_norm', + full_name='paddle_hub.ParamAttr.GradientClipAttr.clip_norm', + index=3, + number=4, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='group_name', + full_name='paddle_hub.ParamAttr.GradientClipAttr.group_name', + index=4, + number=5, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=904, + serialized_end=1001, +) + +_PARAMATTR = _descriptor.Descriptor( + name='ParamAttr', + full_name='paddle_hub.ParamAttr', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='regularizer', + full_name='paddle_hub.ParamAttr.regularizer', + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='gradient_clip_attr', + full_name='paddle_hub.ParamAttr.gradient_clip_attr', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='optimize_attr', + full_name='paddle_hub.ParamAttr.optimize_attr', + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='trainable', + full_name='paddle_hub.ParamAttr.trainable', + index=3, + number=4, + type=8, + cpp_type=7, + label=1, + has_default_value=False, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='do_model_average', + full_name='paddle_hub.ParamAttr.do_model_average', + index=4, + number=5, + type=8, + cpp_type=7, + label=1, + has_default_value=False, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[ + _PARAMATTR_REGULARIZER, + _PARAMATTR_GRADIENTCLIPATTR, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=614, + serialized_end=1001, ) _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor( @@ -278,8 +825,61 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor( syntax='proto3', extension_ranges=[], oneofs=[], - serialized_start=450, - serialized_end=520, + serialized_start=1236, + serialized_end=1306, +) + +_MODULEDESC_PARAMATTRSENTRY = _descriptor.Descriptor( + name='ParamAttrsEntry', + full_name='paddle_hub.ModuleDesc.ParamAttrsEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', + full_name='paddle_hub.ModuleDesc.ParamAttrsEntry.key', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='value', + full_name='paddle_hub.ModuleDesc.ParamAttrsEntry.value', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), + _b('8\001')), + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=1308, + serialized_end=1380, ) _MODULEDESC = _descriptor.Descriptor( @@ -369,10 +969,27 @@ _MODULEDESC = _descriptor.Descriptor( is_extension=False, extension_scope=None, options=None), + _descriptor.FieldDescriptor( + name='param_attrs', + full_name='paddle_hub.ModuleDesc.param_attrs', + index=5, + number=6, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), ], extensions=[], nested_types=[ _MODULEDESC_SIGN2VARENTRY, + _MODULEDESC_PARAMATTRSENTRY, ], enum_types=[], options=None, @@ -380,21 +997,69 @@ _MODULEDESC = _descriptor.Descriptor( syntax='proto3', extension_ranges=[], oneofs=[], - serialized_start=279, - serialized_end=520, + serialized_start=1004, + serialized_end=1380, ) +_FLEXIBLEDATA_MENTRY.fields_by_name['value'].message_type = _FLEXIBLEDATA +_FLEXIBLEDATA_MENTRY.containing_type = _FLEXIBLEDATA +_FLEXIBLEDATA_LENTRY.fields_by_name['value'].message_type = _FLEXIBLEDATA +_FLEXIBLEDATA_LENTRY.containing_type = _FLEXIBLEDATA +_FLEXIBLEDATA.fields_by_name['type'].enum_type = _DATATYPE +_FLEXIBLEDATA.fields_by_name['m'].message_type = _FLEXIBLEDATA_MENTRY +_FLEXIBLEDATA.fields_by_name['l'].message_type = _FLEXIBLEDATA_LENTRY _MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC _MODULEVAR.fields_by_name['feed_desc'].message_type = _FEEDDESC +_PARAMATTR_REGULARIZER.containing_type = _PARAMATTR +_PARAMATTR_GRADIENTCLIPATTR.containing_type = _PARAMATTR +_PARAMATTR.fields_by_name['regularizer'].message_type = _PARAMATTR_REGULARIZER +_PARAMATTR.fields_by_name[ + 'gradient_clip_attr'].message_type = _PARAMATTR_GRADIENTCLIPATTR +_PARAMATTR.fields_by_name['optimize_attr'].message_type = _FLEXIBLEDATA _MODULEDESC_SIGN2VARENTRY.fields_by_name['value'].message_type = _MODULEVAR _MODULEDESC_SIGN2VARENTRY.containing_type = _MODULEDESC +_MODULEDESC_PARAMATTRSENTRY.fields_by_name['value'].message_type = _PARAMATTR +_MODULEDESC_PARAMATTRSENTRY.containing_type = _MODULEDESC _MODULEDESC.fields_by_name['sign2var'].message_type = _MODULEDESC_SIGN2VARENTRY _MODULEDESC.fields_by_name['auth_info'].message_type = _AUTHINFO +_MODULEDESC.fields_by_name[ + 'param_attrs'].message_type = _MODULEDESC_PARAMATTRSENTRY +DESCRIPTOR.message_types_by_name['FlexibleData'] = _FLEXIBLEDATA DESCRIPTOR.message_types_by_name['FeedDesc'] = _FEEDDESC DESCRIPTOR.message_types_by_name['FetchDesc'] = _FETCHDESC DESCRIPTOR.message_types_by_name['ModuleVar'] = _MODULEVAR DESCRIPTOR.message_types_by_name['AuthInfo'] = _AUTHINFO +DESCRIPTOR.message_types_by_name['ParamAttr'] = _PARAMATTR DESCRIPTOR.message_types_by_name['ModuleDesc'] = _MODULEDESC +DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE + +FlexibleData = _reflection.GeneratedProtocolMessageType( + 'FlexibleData', + (_message.Message, ), + dict( + MEntry=_reflection.GeneratedProtocolMessageType( + 'MEntry', + (_message.Message, ), + dict( + DESCRIPTOR=_FLEXIBLEDATA_MENTRY, + __module__='module_desc_pb2' + # @@protoc_insertion_point(class_scope:paddle_hub.FlexibleData.MEntry) + )), + LEntry=_reflection.GeneratedProtocolMessageType( + 'LEntry', + (_message.Message, ), + dict( + DESCRIPTOR=_FLEXIBLEDATA_LENTRY, + __module__='module_desc_pb2' + # @@protoc_insertion_point(class_scope:paddle_hub.FlexibleData.LEntry) + )), + DESCRIPTOR=_FLEXIBLEDATA, + __module__='module_desc_pb2' + # @@protoc_insertion_point(class_scope:paddle_hub.FlexibleData) + )) +_sym_db.RegisterMessage(FlexibleData) +_sym_db.RegisterMessage(FlexibleData.MEntry) +_sym_db.RegisterMessage(FlexibleData.LEntry) FeedDesc = _reflection.GeneratedProtocolMessageType( 'FeedDesc', @@ -436,6 +1101,34 @@ AuthInfo = _reflection.GeneratedProtocolMessageType( )) _sym_db.RegisterMessage(AuthInfo) +ParamAttr = _reflection.GeneratedProtocolMessageType( + 'ParamAttr', + (_message.Message, ), + dict( + Regularizer=_reflection.GeneratedProtocolMessageType( + 'Regularizer', + (_message.Message, ), + dict( + DESCRIPTOR=_PARAMATTR_REGULARIZER, + __module__='module_desc_pb2' + # @@protoc_insertion_point(class_scope:paddle_hub.ParamAttr.Regularizer) + )), + GradientClipAttr=_reflection.GeneratedProtocolMessageType( + 'GradientClipAttr', + (_message.Message, ), + dict( + DESCRIPTOR=_PARAMATTR_GRADIENTCLIPATTR, + __module__='module_desc_pb2' + # @@protoc_insertion_point(class_scope:paddle_hub.ParamAttr.GradientClipAttr) + )), + DESCRIPTOR=_PARAMATTR, + __module__='module_desc_pb2' + # @@protoc_insertion_point(class_scope:paddle_hub.ParamAttr) + )) +_sym_db.RegisterMessage(ParamAttr) +_sym_db.RegisterMessage(ParamAttr.Regularizer) +_sym_db.RegisterMessage(ParamAttr.GradientClipAttr) + ModuleDesc = _reflection.GeneratedProtocolMessageType( 'ModuleDesc', (_message.Message, ), @@ -448,17 +1141,35 @@ ModuleDesc = _reflection.GeneratedProtocolMessageType( __module__='module_desc_pb2' # @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc.Sign2varEntry) )), + ParamAttrsEntry=_reflection.GeneratedProtocolMessageType( + 'ParamAttrsEntry', + (_message.Message, ), + dict( + DESCRIPTOR=_MODULEDESC_PARAMATTRSENTRY, + __module__='module_desc_pb2' + # @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc.ParamAttrsEntry) + )), DESCRIPTOR=_MODULEDESC, __module__='module_desc_pb2' # @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc) )) _sym_db.RegisterMessage(ModuleDesc) _sym_db.RegisterMessage(ModuleDesc.Sign2varEntry) +_sym_db.RegisterMessage(ModuleDesc.ParamAttrsEntry) DESCRIPTOR.has_options = True DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('H\003')) +_FLEXIBLEDATA_MENTRY.has_options = True +_FLEXIBLEDATA_MENTRY._options = _descriptor._ParseOptions( + descriptor_pb2.MessageOptions(), _b('8\001')) +_FLEXIBLEDATA_LENTRY.has_options = True +_FLEXIBLEDATA_LENTRY._options = _descriptor._ParseOptions( + descriptor_pb2.MessageOptions(), _b('8\001')) _MODULEDESC_SIGN2VARENTRY.has_options = True _MODULEDESC_SIGN2VARENTRY._options = _descriptor._ParseOptions( descriptor_pb2.MessageOptions(), _b('8\001')) +_MODULEDESC_PARAMATTRSENTRY.has_options = True +_MODULEDESC_PARAMATTRSENTRY._options = _descriptor._ParseOptions( + descriptor_pb2.MessageOptions(), _b('8\001')) # @@protoc_insertion_point(module_scope) -- GitLab