提交 1ebde3bd 编写于 作者: W wuzewu

optimizer the method of serialize param attribute

上级 bfedfda8
......@@ -30,7 +30,8 @@ from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2
from paddle_hub.logger import logger
from paddle_hub.signature import Signature
from paddle_hub.utils import to_list, get_variable_info, mkdir
from paddle_hub.utils import to_list, mkdir
from paddle_hub.paddle_helper import from_param_to_flexible_data, get_variable_info, from_flexible_data_to_param
from paddle_hub.version import __version__
__all__ = ["Module", "ModuleConfig", "ModuleUtils"]
......@@ -73,48 +74,10 @@ class Module(object):
def _process_parameter(self):
global_block = self.inference_program.global_block()
param_attrs = self.config.desc.param_attrs
for key, param_attr in param_attrs.items():
param = {}
param_attrs = self.config.desc.extra_info.map.data['param_attrs']
for key, param_attr in param_attrs.map.data.items():
param = from_flexible_data_to_param(param_attr)
param['name'] = HUB_VAR_PREFIX + key
param['trainable'] = param_attr.trainable
param['do_model_average'] = param_attr.do_model_average
param['optimize_attr'] = {}
param['optimize_attr'][
'learning_rate'] = param_attr.optimize_attr.m['learning_rate'].f
# TODO(wuzewu): recover the param attr with a more reliable way
if param_attr.regularizer.type == "L2DecayRegularizer":
regularizer = fluid.regularizer.L2DecayRegularizer(
regularization_coeff=param_attr.regularizer.
regularization_coeff)
elif param_attr.regularizer.type == "L1DecayRegularizer":
regularizer = fluid.regularizer.L1DecayRegularizer(
regularization_coeff=param_attr.regularizer.
regularization_coeff)
else:
regularizer = None
param['regularizer'] = regularizer
if param_attr.gradient_clip_attr.type == "ErrorClipByValue":
clip = fluid.clip.ErrorClipByValue(
max=param_attr.gradient_clip_attr.max,
min=param_attr.gradient_clip_attr.min)
elif param_attr.gradient_clip_attr.type == "GradientClipByValue":
clip = fluid.clip.GradientClipByValue(
max=param_attr.gradient_clip_attr.max,
min=param_attr.gradient_clip_attr.min)
elif param_attr.gradient_clip_attr.type == "GradientClipByNorm":
clip = fluid.clip.GradientClipByNorm(
clip_norm=param_attr.gradient_clip_attr.clip_norm)
elif param_attr.gradient_clip_attr.type == "GradientClipByGlobalNorm":
clip = fluid.clip.GradientClipByGlobalNorm(
clip_norm=param_attr.gradient_clip_attr.clip_norm,
group_name=param_attr.gradient_clip_attr.group_name)
else:
clip = None
param['gradient_clip_attr'] = clip
if (param['name'] not in global_block.vars):
continue
var = global_block.var(param['name'])
......@@ -341,46 +304,13 @@ def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
fo.write("{}\t{}\n".format(w, w_id))
# save fluid Parameter
param_attrs = module_desc.param_attrs
extra_info = module_desc.extra_info
extra_info.type = module_desc_pb2.MAP
param_attrs = extra_info.map.data['param_attrs']
param_attrs.type = module_desc_pb2.MAP
for param in program.global_block().iter_parameters():
param_attr = param_attrs[param.name]
param_attr.trainable = param.trainable
if param.do_model_average:
param_attr.do_model_average = param.do_model_average
# TODO(wuzewu): add a func to transfer python dict to fexiable data
param_attr.optimize_attr.type = module_desc_pb2.MAP
param_attr.optimize_attr.m['learning_rate'].type = module_desc_pb2.FLOAT
param_attr.optimize_attr.m['learning_rate'].f = param.optimize_attr[
'learning_rate']
if param.regularizer:
if isinstance(param.regularizer,
fluid.regularizer.L2DecayRegularizer):
param_attr.regularizer.type = "L2DecayRegularizer"
if isinstance(param.regularizer,
fluid.regularizer.L1DecayRegularizer):
param_attr.regularizer.type = "L1DecayRegularizer"
param_attr.regularizer.regularization_coeff = param.regularizer.regularization_coeff
if param.gradient_clip_attr:
if isinstance(param.gradient_clip_attr,
fluid.clip.ErrorClipByValue):
param_attr.gradient_clip_attr.max = param.gradient_clip_attr.max
param_attr.gradient_clip_attr.min = param.gradient_clip_attr.min
param_attr.gradient_clip_attr.type = "ErrorClipByValue"
if isinstance(param.gradient_clip_attr,
fluid.clip.GradientClipByValue):
param_attr.gradient_clip_attr.max = param.gradient_clip_attr.max
param_attr.gradient_clip_attr.min = param.gradient_clip_attr.min
param_attr.gradient_clip_attr.type = "GradientClipByValue"
if isinstance(param.gradient_clip_attr,
fluid.clip.GradientClipByNorm):
param_attr.gradient_clip_attr.clip_norm = param.gradient_clip_attr.clip_norm
param_attr.gradient_clip_attr.type = "GradientClipByNorm"
if isinstance(param.gradient_clip_attr,
fluid.clip.GradientClipByGlobalNorm):
param_attr.gradient_clip_attr.clip_norm = param.gradient_clip_attr.clip_norm
param_attr.gradient_clip_attr.group_name = param.gradient_clip_attr.group_name
param_attr.gradient_clip_attr.type = "GradientClipByGlobalNorm"
param_attr = param_attrs.map.data[param.name]
from_param_to_flexible_data(param, param_attr)
# save signarture info
sign_map = module_desc.sign2var
......
......@@ -19,23 +19,34 @@ option optimize_for = LITE_RUNTIME;
package paddle_hub;
enum DataType {
INT = 0;
FLOAT = 1;
STRING = 2;
BOOLEAN = 3;
LIST = 4;
MAP = 5;
NONE = 0;
INT = 1;
FLOAT = 2;
STRING = 3;
BOOLEAN = 4;
LIST = 5;
MAP = 6;
SET = 7;
OBJECT = 8;
}
message KVData {
map<string, DataType> keyType = 1;
map<string, FlexibleData> data = 2;
}
message FlexibleData {
DataType type = 1;
string name = 2;
int32 i = 3;
int64 i = 3;
float f = 4;
bool b = 5;
string s = 6;
map<string, FlexibleData> m = 7;
map<int32, FlexibleData> l = 8;
KVData map = 7;
KVData list = 8;
KVData set = 9;
KVData object = 10;
string info = 11;
}
// Feed Variable Description
......@@ -61,27 +72,6 @@ message AuthInfo {
string hub_version = 2;
}
message ParamAttr {
message Regularizer {
string type = 1;
float regularization_coeff = 2;
}
message GradientClipAttr {
string type = 1;
float min = 2;
float max = 3;
float clip_norm = 4;
string group_name = 5;
}
Regularizer regularizer = 1;
GradientClipAttr gradient_clip_attr = 2;
FlexibleData optimize_attr = 3;
bool trainable = 4;
bool do_model_average = 5;
}
// A Hub Module is stored in a directory with a file 'paddlehub.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
......@@ -98,6 +88,6 @@ message ModuleDesc {
AuthInfo auth_info = 5;
map<string, ParamAttr> param_attrs = 6;
FlexibleData extra_info = 6;
};
......@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle_hub',
syntax='proto3',
serialized_pb=_b(
'\n\x11module_desc.proto\x12\npaddle_hub\"\xcc\x02\n\x0c\x46lexibleData\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.paddle_hub.DataType\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\t\n\x01i\x18\x03 \x01(\x05\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01\x62\x18\x05 \x01(\x08\x12\t\n\x01s\x18\x06 \x01(\t\x12*\n\x01m\x18\x07 \x03(\x0b\x32\x1f.paddle_hub.FlexibleData.MEntry\x12*\n\x01l\x18\x08 \x03(\x0b\x32\x1f.paddle_hub.FlexibleData.LEntry\x1a\x42\n\x06MEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.paddle_hub.FlexibleData:\x02\x38\x01\x1a\x42\n\x06LEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.paddle_hub.FlexibleData:\x02\x38\x01\"+\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\",\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"7\n\x08\x41uthInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\"\x83\x03\n\tParamAttr\x12\x36\n\x0bregularizer\x18\x01 \x01(\x0b\x32!.paddle_hub.ParamAttr.Regularizer\x12\x42\n\x12gradient_clip_attr\x18\x02 \x01(\x0b\x32&.paddle_hub.ParamAttr.GradientClipAttr\x12/\n\roptimize_attr\x18\x03 \x01(\x0b\x32\x18.paddle_hub.FlexibleData\x12\x11\n\ttrainable\x18\x04 \x01(\x08\x12\x18\n\x10\x64o_model_average\x18\x05 \x01(\x08\x1a\x39\n\x0bRegularizer\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x1c\n\x14regularization_coeff\x18\x02 \x01(\x02\x1a\x61\n\x10GradientClipAttr\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03min\x18\x02 \x01(\x02\x12\x0b\n\x03max\x18\x03 \x01(\x02\x12\x11\n\tclip_norm\x18\x04 \x01(\x02\x12\x12\n\ngroup_name\x18\x05 \x01(\t\"\xf8\x02\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12\'\n\tauth_info\x18\x05 \x01(\x0b\x32\x14.paddle_hub.AuthInfo\x12;\n\x0bparam_attrs\x18\x06 \x03(\x0b\x32&.paddle_hub.ModuleDesc.ParamAttrsEntry\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x1aH\n\x0fParamAttrsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ParamAttr:\x02\x38\x01*J\n\x08\x44\x61taType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x0b\n\x07\x42OOLEAN\x10\x03\x12\x08\n\x04LIST\x10\x04\x12\x07\n\x03MAP\x10\x05\x42\x02H\x03\x62\x06proto3'
'\n\x11module_desc.proto\x12\npaddle_hub\"\xf3\x01\n\x06KVData\x12\x30\n\x07keyType\x18\x01 \x03(\x0b\x32\x1f.paddle_hub.KVData.KeyTypeEntry\x12*\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32\x1c.paddle_hub.KVData.DataEntry\x1a\x44\n\x0cKeyTypeEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0e\x32\x14.paddle_hub.DataType:\x02\x38\x01\x1a\x45\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.paddle_hub.FlexibleData:\x02\x38\x01\"\x82\x02\n\x0c\x46lexibleData\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.paddle_hub.DataType\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01\x62\x18\x05 \x01(\x08\x12\t\n\x01s\x18\x06 \x01(\t\x12\x1f\n\x03map\x18\x07 \x01(\x0b\x32\x12.paddle_hub.KVData\x12 \n\x04list\x18\x08 \x01(\x0b\x32\x12.paddle_hub.KVData\x12\x1f\n\x03set\x18\t \x01(\x0b\x32\x12.paddle_hub.KVData\x12\"\n\x06object\x18\n \x01(\x0b\x32\x12.paddle_hub.KVData\x12\x0c\n\x04info\x18\x0b \x01(\t\"+\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\",\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"7\n\x08\x41uthInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\"\x9f\x02\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12\'\n\tauth_info\x18\x05 \x01(\x0b\x32\x14.paddle_hub.AuthInfo\x12,\n\nextra_info\x18\x06 \x01(\x0b\x32\x18.paddle_hub.FlexibleData\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01*i\n\x08\x44\x61taType\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03INT\x10\x01\x12\t\n\x05\x46LOAT\x10\x02\x12\n\n\x06STRING\x10\x03\x12\x0b\n\x07\x42OOLEAN\x10\x04\x12\x08\n\x04LIST\x10\x05\x12\x07\n\x03MAP\x10\x06\x12\x07\n\x03SET\x10\x07\x12\n\n\x06OBJECT\x10\x08\x42\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -29,43 +29,52 @@ _DATATYPE = _descriptor.EnumDescriptor(
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='INT', index=0, number=0, options=None, type=None),
name='NONE', index=0, number=0, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FLOAT', index=1, number=1, options=None, type=None),
name='INT', index=1, number=1, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='STRING', index=2, number=2, options=None, type=None),
name='FLOAT', index=2, number=2, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEAN', index=3, number=3, options=None, type=None),
name='STRING', index=3, number=3, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='LIST', index=4, number=4, options=None, type=None),
name='BOOLEAN', index=4, number=4, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='MAP', index=5, number=5, options=None, type=None),
name='LIST', index=5, number=5, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='MAP', index=6, number=6, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='SET', index=7, number=7, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='OBJECT', index=8, number=8, options=None, type=None),
],
containing_type=None,
options=None,
serialized_start=1382,
serialized_end=1456,
serialized_start=1075,
serialized_end=1180,
)
_sym_db.RegisterEnumDescriptor(_DATATYPE)
DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE)
INT = 0
FLOAT = 1
STRING = 2
BOOLEAN = 3
LIST = 4
MAP = 5
NONE = 0
INT = 1
FLOAT = 2
STRING = 3
BOOLEAN = 4
LIST = 5
MAP = 6
SET = 7
OBJECT = 8
_FLEXIBLEDATA_MENTRY = _descriptor.Descriptor(
name='MEntry',
full_name='paddle_hub.FlexibleData.MEntry',
_KVDATA_KEYTYPEENTRY = _descriptor.Descriptor(
name='KeyTypeEntry',
full_name='paddle_hub.KVData.KeyTypeEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.FlexibleData.MEntry.key',
full_name='paddle_hub.KVData.KeyTypeEntry.key',
index=0,
number=1,
type=9,
......@@ -81,14 +90,14 @@ _FLEXIBLEDATA_MENTRY = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='value',
full_name='paddle_hub.FlexibleData.MEntry.value',
full_name='paddle_hub.KVData.KeyTypeEntry.value',
index=1,
number=2,
type=11,
cpp_type=10,
type=14,
cpp_type=8,
label=1,
has_default_value=False,
default_value=None,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
......@@ -105,27 +114,27 @@ _FLEXIBLEDATA_MENTRY = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=232,
serialized_end=298,
serialized_start=138,
serialized_end=206,
)
_FLEXIBLEDATA_LENTRY = _descriptor.Descriptor(
name='LEntry',
full_name='paddle_hub.FlexibleData.LEntry',
_KVDATA_DATAENTRY = _descriptor.Descriptor(
name='DataEntry',
full_name='paddle_hub.KVData.DataEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.FlexibleData.LEntry.key',
full_name='paddle_hub.KVData.DataEntry.key',
index=0,
number=1,
type=5,
cpp_type=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=0,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
......@@ -134,7 +143,7 @@ _FLEXIBLEDATA_LENTRY = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='value',
full_name='paddle_hub.FlexibleData.LEntry.value',
full_name='paddle_hub.KVData.DataEntry.value',
index=1,
number=2,
type=11,
......@@ -158,8 +167,63 @@ _FLEXIBLEDATA_LENTRY = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=300,
serialized_end=366,
serialized_start=208,
serialized_end=277,
)
_KVDATA = _descriptor.Descriptor(
name='KVData',
full_name='paddle_hub.KVData',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='keyType',
full_name='paddle_hub.KVData.keyType',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data',
full_name='paddle_hub.KVData.data',
index=1,
number=2,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[
_KVDATA_KEYTYPEENTRY,
_KVDATA_DATAENTRY,
],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=34,
serialized_end=277,
)
_FLEXIBLEDATA = _descriptor.Descriptor(
......@@ -206,8 +270,8 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
full_name='paddle_hub.FlexibleData.i',
index=2,
number=3,
type=5,
cpp_type=1,
type=3,
cpp_type=2,
label=1,
has_default_value=False,
default_value=0,
......@@ -266,15 +330,15 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='m',
full_name='paddle_hub.FlexibleData.m',
name='map',
full_name='paddle_hub.FlexibleData.map',
index=6,
number=7,
type=11,
cpp_type=10,
label=3,
label=1,
has_default_value=False,
default_value=[],
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
......@@ -282,15 +346,63 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='l',
full_name='paddle_hub.FlexibleData.l',
name='list',
full_name='paddle_hub.FlexibleData.list',
index=7,
number=8,
type=11,
cpp_type=10,
label=3,
label=1,
has_default_value=False,
default_value=[],
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='set',
full_name='paddle_hub.FlexibleData.set',
index=8,
number=9,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='object',
full_name='paddle_hub.FlexibleData.object',
index=9,
number=10,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='info',
full_name='paddle_hub.FlexibleData.info',
index=10,
number=11,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
......@@ -299,18 +411,15 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
],
extensions=[],
nested_types=[
_FLEXIBLEDATA_MENTRY,
_FLEXIBLEDATA_LENTRY,
],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=34,
serialized_end=366,
serialized_start=280,
serialized_end=538,
)
_FEEDDESC = _descriptor.Descriptor(
......@@ -361,8 +470,8 @@ _FEEDDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=368,
serialized_end=411,
serialized_start=540,
serialized_end=583,
)
_FETCHDESC = _descriptor.Descriptor(
......@@ -413,8 +522,8 @@ _FETCHDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=413,
serialized_end=457,
serialized_start=585,
serialized_end=629,
)
_MODULEVAR = _descriptor.Descriptor(
......@@ -465,8 +574,8 @@ _MODULEVAR = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=459,
serialized_end=554,
serialized_start=631,
serialized_end=726,
)
_AUTHINFO = _descriptor.Descriptor(
......@@ -517,263 +626,8 @@ _AUTHINFO = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=556,
serialized_end=611,
)
_PARAMATTR_REGULARIZER = _descriptor.Descriptor(
name='Regularizer',
full_name='paddle_hub.ParamAttr.Regularizer',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='paddle_hub.ParamAttr.Regularizer.type',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='regularization_coeff',
full_name='paddle_hub.ParamAttr.Regularizer.regularization_coeff',
index=1,
number=2,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=845,
serialized_end=902,
)
_PARAMATTR_GRADIENTCLIPATTR = _descriptor.Descriptor(
name='GradientClipAttr',
full_name='paddle_hub.ParamAttr.GradientClipAttr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='paddle_hub.ParamAttr.GradientClipAttr.type',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='min',
full_name='paddle_hub.ParamAttr.GradientClipAttr.min',
index=1,
number=2,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='max',
full_name='paddle_hub.ParamAttr.GradientClipAttr.max',
index=2,
number=3,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='clip_norm',
full_name='paddle_hub.ParamAttr.GradientClipAttr.clip_norm',
index=3,
number=4,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='group_name',
full_name='paddle_hub.ParamAttr.GradientClipAttr.group_name',
index=4,
number=5,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=904,
serialized_end=1001,
)
_PARAMATTR = _descriptor.Descriptor(
name='ParamAttr',
full_name='paddle_hub.ParamAttr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='regularizer',
full_name='paddle_hub.ParamAttr.regularizer',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='gradient_clip_attr',
full_name='paddle_hub.ParamAttr.gradient_clip_attr',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='optimize_attr',
full_name='paddle_hub.ParamAttr.optimize_attr',
index=2,
number=3,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='trainable',
full_name='paddle_hub.ParamAttr.trainable',
index=3,
number=4,
type=8,
cpp_type=7,
label=1,
has_default_value=False,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='do_model_average',
full_name='paddle_hub.ParamAttr.do_model_average',
index=4,
number=5,
type=8,
cpp_type=7,
label=1,
has_default_value=False,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[
_PARAMATTR_REGULARIZER,
_PARAMATTR_GRADIENTCLIPATTR,
],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=614,
serialized_end=1001,
serialized_start=728,
serialized_end=783,
)
_MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
......@@ -825,61 +679,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=1236,
serialized_end=1306,
)
_MODULEDESC_PARAMATTRSENTRY = _descriptor.Descriptor(
name='ParamAttrsEntry',
full_name='paddle_hub.ModuleDesc.ParamAttrsEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.ModuleDesc.ParamAttrsEntry.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='value',
full_name='paddle_hub.ModuleDesc.ParamAttrsEntry.value',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(),
_b('8\001')),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=1308,
serialized_end=1380,
serialized_start=1003,
serialized_end=1073,
)
_MODULEDESC = _descriptor.Descriptor(
......@@ -970,15 +771,15 @@ _MODULEDESC = _descriptor.Descriptor(
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='param_attrs',
full_name='paddle_hub.ModuleDesc.param_attrs',
name='extra_info',
full_name='paddle_hub.ModuleDesc.extra_info',
index=5,
number=6,
type=11,
cpp_type=10,
label=3,
label=1,
has_default_value=False,
default_value=[],
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
......@@ -989,7 +790,6 @@ _MODULEDESC = _descriptor.Descriptor(
extensions=[],
nested_types=[
_MODULEDESC_SIGN2VARENTRY,
_MODULEDESC_PARAMATTRSENTRY,
],
enum_types=[],
options=None,
......@@ -997,69 +797,74 @@ _MODULEDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=1004,
serialized_end=1380,
serialized_start=786,
serialized_end=1073,
)
_FLEXIBLEDATA_MENTRY.fields_by_name['value'].message_type = _FLEXIBLEDATA
_FLEXIBLEDATA_MENTRY.containing_type = _FLEXIBLEDATA
_FLEXIBLEDATA_LENTRY.fields_by_name['value'].message_type = _FLEXIBLEDATA
_FLEXIBLEDATA_LENTRY.containing_type = _FLEXIBLEDATA
_KVDATA_KEYTYPEENTRY.fields_by_name['value'].enum_type = _DATATYPE
_KVDATA_KEYTYPEENTRY.containing_type = _KVDATA
_KVDATA_DATAENTRY.fields_by_name['value'].message_type = _FLEXIBLEDATA
_KVDATA_DATAENTRY.containing_type = _KVDATA
_KVDATA.fields_by_name['keyType'].message_type = _KVDATA_KEYTYPEENTRY
_KVDATA.fields_by_name['data'].message_type = _KVDATA_DATAENTRY
_FLEXIBLEDATA.fields_by_name['type'].enum_type = _DATATYPE
_FLEXIBLEDATA.fields_by_name['m'].message_type = _FLEXIBLEDATA_MENTRY
_FLEXIBLEDATA.fields_by_name['l'].message_type = _FLEXIBLEDATA_LENTRY
_FLEXIBLEDATA.fields_by_name['map'].message_type = _KVDATA
_FLEXIBLEDATA.fields_by_name['list'].message_type = _KVDATA
_FLEXIBLEDATA.fields_by_name['set'].message_type = _KVDATA
_FLEXIBLEDATA.fields_by_name['object'].message_type = _KVDATA
_MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC
_MODULEVAR.fields_by_name['feed_desc'].message_type = _FEEDDESC
_PARAMATTR_REGULARIZER.containing_type = _PARAMATTR
_PARAMATTR_GRADIENTCLIPATTR.containing_type = _PARAMATTR
_PARAMATTR.fields_by_name['regularizer'].message_type = _PARAMATTR_REGULARIZER
_PARAMATTR.fields_by_name[
'gradient_clip_attr'].message_type = _PARAMATTR_GRADIENTCLIPATTR
_PARAMATTR.fields_by_name['optimize_attr'].message_type = _FLEXIBLEDATA
_MODULEDESC_SIGN2VARENTRY.fields_by_name['value'].message_type = _MODULEVAR
_MODULEDESC_SIGN2VARENTRY.containing_type = _MODULEDESC
_MODULEDESC_PARAMATTRSENTRY.fields_by_name['value'].message_type = _PARAMATTR
_MODULEDESC_PARAMATTRSENTRY.containing_type = _MODULEDESC
_MODULEDESC.fields_by_name['sign2var'].message_type = _MODULEDESC_SIGN2VARENTRY
_MODULEDESC.fields_by_name['auth_info'].message_type = _AUTHINFO
_MODULEDESC.fields_by_name[
'param_attrs'].message_type = _MODULEDESC_PARAMATTRSENTRY
_MODULEDESC.fields_by_name['extra_info'].message_type = _FLEXIBLEDATA
DESCRIPTOR.message_types_by_name['KVData'] = _KVDATA
DESCRIPTOR.message_types_by_name['FlexibleData'] = _FLEXIBLEDATA
DESCRIPTOR.message_types_by_name['FeedDesc'] = _FEEDDESC
DESCRIPTOR.message_types_by_name['FetchDesc'] = _FETCHDESC
DESCRIPTOR.message_types_by_name['ModuleVar'] = _MODULEVAR
DESCRIPTOR.message_types_by_name['AuthInfo'] = _AUTHINFO
DESCRIPTOR.message_types_by_name['ParamAttr'] = _PARAMATTR
DESCRIPTOR.message_types_by_name['ModuleDesc'] = _MODULEDESC
DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE
FlexibleData = _reflection.GeneratedProtocolMessageType(
'FlexibleData',
KVData = _reflection.GeneratedProtocolMessageType(
'KVData',
(_message.Message, ),
dict(
MEntry=_reflection.GeneratedProtocolMessageType(
'MEntry',
KeyTypeEntry=_reflection.GeneratedProtocolMessageType(
'KeyTypeEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_FLEXIBLEDATA_MENTRY,
DESCRIPTOR=_KVDATA_KEYTYPEENTRY,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.FlexibleData.MEntry)
# @@protoc_insertion_point(class_scope:paddle_hub.KVData.KeyTypeEntry)
)),
LEntry=_reflection.GeneratedProtocolMessageType(
'LEntry',
DataEntry=_reflection.GeneratedProtocolMessageType(
'DataEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_FLEXIBLEDATA_LENTRY,
DESCRIPTOR=_KVDATA_DATAENTRY,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.FlexibleData.LEntry)
# @@protoc_insertion_point(class_scope:paddle_hub.KVData.DataEntry)
)),
DESCRIPTOR=_KVDATA,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.KVData)
))
_sym_db.RegisterMessage(KVData)
_sym_db.RegisterMessage(KVData.KeyTypeEntry)
_sym_db.RegisterMessage(KVData.DataEntry)
FlexibleData = _reflection.GeneratedProtocolMessageType(
'FlexibleData',
(_message.Message, ),
dict(
DESCRIPTOR=_FLEXIBLEDATA,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.FlexibleData)
))
_sym_db.RegisterMessage(FlexibleData)
_sym_db.RegisterMessage(FlexibleData.MEntry)
_sym_db.RegisterMessage(FlexibleData.LEntry)
FeedDesc = _reflection.GeneratedProtocolMessageType(
'FeedDesc',
......@@ -1101,34 +906,6 @@ AuthInfo = _reflection.GeneratedProtocolMessageType(
))
_sym_db.RegisterMessage(AuthInfo)
ParamAttr = _reflection.GeneratedProtocolMessageType(
'ParamAttr',
(_message.Message, ),
dict(
Regularizer=_reflection.GeneratedProtocolMessageType(
'Regularizer',
(_message.Message, ),
dict(
DESCRIPTOR=_PARAMATTR_REGULARIZER,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ParamAttr.Regularizer)
)),
GradientClipAttr=_reflection.GeneratedProtocolMessageType(
'GradientClipAttr',
(_message.Message, ),
dict(
DESCRIPTOR=_PARAMATTR_GRADIENTCLIPATTR,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ParamAttr.GradientClipAttr)
)),
DESCRIPTOR=_PARAMATTR,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ParamAttr)
))
_sym_db.RegisterMessage(ParamAttr)
_sym_db.RegisterMessage(ParamAttr.Regularizer)
_sym_db.RegisterMessage(ParamAttr.GradientClipAttr)
ModuleDesc = _reflection.GeneratedProtocolMessageType(
'ModuleDesc',
(_message.Message, ),
......@@ -1141,35 +918,23 @@ ModuleDesc = _reflection.GeneratedProtocolMessageType(
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc.Sign2varEntry)
)),
ParamAttrsEntry=_reflection.GeneratedProtocolMessageType(
'ParamAttrsEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_MODULEDESC_PARAMATTRSENTRY,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc.ParamAttrsEntry)
)),
DESCRIPTOR=_MODULEDESC,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc)
))
_sym_db.RegisterMessage(ModuleDesc)
_sym_db.RegisterMessage(ModuleDesc.Sign2varEntry)
_sym_db.RegisterMessage(ModuleDesc.ParamAttrsEntry)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(),
_b('H\003'))
_FLEXIBLEDATA_MENTRY.has_options = True
_FLEXIBLEDATA_MENTRY._options = _descriptor._ParseOptions(
_KVDATA_KEYTYPEENTRY.has_options = True
_KVDATA_KEYTYPEENTRY._options = _descriptor._ParseOptions(
descriptor_pb2.MessageOptions(), _b('8\001'))
_FLEXIBLEDATA_LENTRY.has_options = True
_FLEXIBLEDATA_LENTRY._options = _descriptor._ParseOptions(
_KVDATA_DATAENTRY.has_options = True
_KVDATA_DATAENTRY._options = _descriptor._ParseOptions(
descriptor_pb2.MessageOptions(), _b('8\001'))
_MODULEDESC_SIGN2VARENTRY.has_options = True
_MODULEDESC_SIGN2VARENTRY._options = _descriptor._ParseOptions(
descriptor_pb2.MessageOptions(), _b('8\001'))
_MODULEDESC_PARAMATTRSENTRY.has_options = True
_MODULEDESC_PARAMATTRSENTRY._options = _descriptor._ParseOptions(
descriptor_pb2.MessageOptions(), _b('8\001'))
# @@protoc_insertion_point(module_scope)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub import module_desc_pb2
from paddle_hub.utils import from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
import paddle
import paddle.fluid as fluid
def get_variable_info(var):
assert isinstance(
var,
fluid.framework.Variable), "var should be a fluid.framework.Variable"
var_info = {
'type': var.type,
'name': var.name,
'dtype': var.dtype,
'lod_level': var.lod_level,
'shape': var.shape,
'stop_gradient': var.stop_gradient,
'is_data': var.is_data,
'error_clip': var.error_clip
}
if isinstance(var, fluid.framework.Parameter):
var_info['trainable'] = var.trainable
var_info['optimize_attr'] = var.optimize_attr
var_info['regularizer'] = var.regularizer
var_info['gradient_clip_attr'] = var.gradient_clip_attr
var_info['do_model_average'] = var.do_model_average
else:
var_info['persistable'] = var.persistable
return var_info
def from_param_to_flexible_data(param, flexible_data):
flexible_data.type = module_desc_pb2.MAP
from_pyobj_to_flexible_data(param.trainable,
flexible_data.map.data['trainable'])
from_pyobj_to_flexible_data(param.do_model_average,
flexible_data.map.data['do_model_average'])
from_pyobj_to_flexible_data(param.optimize_attr,
flexible_data.map.data['optimize_attr'])
from_pyobj_to_flexible_data(param.regularizer,
flexible_data.map.data['regularizer'])
from_pyobj_to_flexible_data(param.gradient_clip_attr,
flexible_data.map.data['gradient_clip_attr'])
def from_flexible_data_to_param(flexible_data):
param = {'gradient_clip_attr': None, 'regularizer': None}
param['trainable'] = from_flexible_data_to_pyobj(
flexible_data.map.data['trainable'])
param['do_model_average'] = from_flexible_data_to_pyobj(
flexible_data.map.data['do_model_average'])
param['optimize_attr'] = from_flexible_data_to_pyobj(
flexible_data.map.data['optimize_attr'])
if flexible_data.map.data['regularizer'].type != module_desc_pb2.NONE:
regularizer_type = flexible_data.map.data['regularizer'].name
regularization_coeff = flexible_data.map.data[
'regularizer'].object.data['_regularization_coeff '].f
param['regularizer'] = eval(
"fluid.regularizer.%s(regularization_coeff = %f)" %
(regularizer_type, regularization_coeff))
if flexible_data.map.data['regularizer'].type != module_desc_pb2.NONE:
clip_type = flexible_data.map.data['gradient_clip_attr'].name
if clip_type == "ErrorClipByValue" or clip_type == "GradientClipByValue":
max = flexible_data.map.data[
'regularizer'].name, flexible_data.map.data[
'gradient_clip_attr'].object.data['max'].f
min = flexible_data.map.data[
'regularizer'].name, flexible_data.map.data[
'gradient_clip_attr'].object.data['min'].f
param['gradient_clip_attr'] = eval(
"fluid.clip.%s(max = %f, min = %f)" % (clip_type, max, min))
if clip_type == "GradientClipByNorm":
clip_norm = flexible_data.map.data[
'gradient_clip_attr'].object.data['clip_norm'].f
param['gradient_clip_attr'] = eval(
"fluid.clip.%s(clip_norm = %f)" % (clip_type, clip_norm))
if clip_type == "GradientClipByGlobalNorm":
clip_norm = flexible_data.map.data[
'gradient_clip_attr'].object.data['clip_norm'].f
group_name = flexible_data.map.data[
'gradient_clip_attr'].object.data['group_name'].f
param['gradient_clip_attr'] = eval(
"fluid.clip.%s(clip_norm = %f, group_name = %f)" %
(clip_type, clip_norm, group_name))
return param
......@@ -17,6 +17,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub import module_desc_pb2
from paddle_hub.logger import logger
import paddle
import paddle.fluid as fluid
import os
......@@ -30,34 +32,95 @@ def to_list(input):
return input
def get_variable_info(var):
assert isinstance(
var,
fluid.framework.Variable), "var should be a fluid.framework.Variable"
var_info = {
'type': var.type,
'name': var.name,
'dtype': var.dtype,
'lod_level': var.lod_level,
'shape': var.shape,
'stop_gradient': var.stop_gradient,
'is_data': var.is_data,
'error_clip': var.error_clip
}
if isinstance(var, fluid.framework.Parameter):
var_info['trainable'] = var.trainable
var_info['optimize_attr'] = var.optimize_attr
var_info['regularizer'] = var.regularizer
var_info['gradient_clip_attr'] = var.gradient_clip_attr
var_info['do_model_average'] = var.do_model_average
else:
var_info['persistable'] = var.persistable
return var_info
def mkdir(path):
""" the same as the shell command mkdir -p "
"""
if not os.path.exists(path):
os.makedirs(path)
def get_keyed_type_of_pyobj(pyobj):
if isinstance(pyobj, bool):
return module_desc_pb2.BOOLEAN
elif isinstance(pyobj, int):
return module_desc_pb2.INT
elif isinstance(pyobj, str):
return module_desc_pb2.STRING
elif isinstance(pyobj, float):
return module_desc_pb2.FLOAT
return module_desc_pb2.STRING
def from_pyobj_to_flexible_data(pyobj, flexible_data):
if isinstance(pyobj, bool):
flexible_data.type = module_desc_pb2.BOOLEAN
flexible_data.b = pyobj
elif isinstance(pyobj, int):
flexible_data.type = module_desc_pb2.INT
flexible_data.i = pyobj
elif isinstance(pyobj, str):
flexible_data.type = module_desc_pb2.STRING
flexible_data.s = pyobj
elif isinstance(pyobj, float):
flexible_data.type = module_desc_pb2.FLOAT
flexible_data.f = pyobj
elif isinstance(pyobj, list) or isinstance(pyobj, tuple):
flexible_data.type = module_desc_pb2.LIST
for index, obj in enumerate(pyobj):
from_pyobj_to_flexible_data(obj,
flexible_data.list.data[str(index)])
elif isinstance(pyobj, set):
flexible_data.type = module_desc_pb2.SET
for index, obj in enumerate(list(pyobj)):
from_pyobj_to_flexible_data(obj, flexible_data.set.data[str(index)])
elif isinstance(pyobj, dict):
flexible_data.type = module_desc_pb2.MAP
for key, value in pyobj.items():
from_pyobj_to_flexible_data(value, flexible_data.map.data[str(key)])
flexible_data.map.keyType[str(key)] = get_keyed_type_of_pyobj(key)
elif isinstance(pyobj, type(None)):
flexible_data.type = module_desc_pb2.NONE
else:
flexible_data.type = module_desc_pb2.OBJECT
flexible_data.name = str(pyobj.__class__.__name__)
for key, value in pyobj.__dict__.items():
from_pyobj_to_flexible_data(value,
flexible_data.object.data[str(key)])
flexible_data.object.keyType[str(key)] = get_keyed_type_of_pyobj(
key)
def from_flexible_data_to_pyobj(flexible_data):
if flexible_data.type == module_desc_pb2.BOOLEAN:
result = flexible_data.b
elif flexible_data.type == module_desc_pb2.INT:
result = flexible_data.i
elif flexible_data.type == module_desc_pb2.STRING:
result = flexible_data.s
elif flexible_data.type == module_desc_pb2.FLOAT:
result = flexible_data.f
elif flexible_data.type == module_desc_pb2.LIST:
result = []
for index in range(len(flexible_data.list.data)):
result.append(
from_flexible_data_to_pyobj(flexible_data.m.data(str(index))))
elif flexible_data.type == module_desc_pb2.SET:
result = set()
for index in range(len(flexible_data.set.data)):
result.add(
from_flexible_data_to_pyobj(flexible_data.m.data(str(index))))
elif flexible_data.type == module_desc_pb2.MAP:
result = {}
for key, value in flexible_data.map.data.items():
key = flexible_data.map.keyType[key]
result[key] = from_flexible_data_to_pyobj(value)
elif flexible_data.type == module_desc_pb2.NONE:
result = None
elif flexible_data.type == module_desc_pb2.OBJECT:
result = None
logger.warning("can't tran flexible_data to python object")
else:
result = None
logger.warning("unknown type of flexible_data")
return result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册