Commit c9ac67d6 authored by W wuzewu

save param attr with pb format

Parent 86470a5d
@@ -80,11 +80,48 @@ class Module(object):
def _process_parameter(self):
global_block = self.inference_program.global_block()
param_path = ModuleConfig.meta_param_path(self.module_dir)
with open(param_path, "rb") as file:
param_arr = pickle.load(file)
for param in param_arr:
param['name'] = HUB_VAR_PREFIX + param['name']
param_attrs = self.config.desc.param_attrs
for key, param_attr in param_attrs.items():
param = {}
param['name'] = HUB_VAR_PREFIX + key
param['trainable'] = param_attr.trainable
param['do_model_average'] = param_attr.do_model_average
param['optimize_attr'] = {}
param['optimize_attr'][
'learning_rate'] = param_attr.optimize_attr.m['learning_rate'].f
# TODO(wuzewu): recover the param attr with a more reliable way
if param_attr.regularizer.type == "L2DecayRegularizer":
regularizer = fluid.regularizer.L2DecayRegularizer(
regularization_coeff=param_attr.regularizer.
regularization_coeff)
elif param_attr.regularizer.type == "L1DecayRegularizer":
regularizer = fluid.regularizer.L1DecayRegularizer(
regularization_coeff=param_attr.regularizer.
regularization_coeff)
else:
regularizer = None
param['regularizer'] = regularizer
if param_attr.gradient_clip_attr.type == "ErrorClipByValue":
clip = fluid.clip.ErrorClipByValue(
max=param_attr.gradient_clip_attr.max,
min=param_attr.gradient_clip_attr.min)
elif param_attr.gradient_clip_attr.type == "GradientClipByValue":
clip = fluid.clip.GradientClipByValue(
max=param_attr.gradient_clip_attr.max,
min=param_attr.gradient_clip_attr.min)
elif param_attr.gradient_clip_attr.type == "GradientClipByNorm":
clip = fluid.clip.GradientClipByNorm(
clip_norm=param_attr.gradient_clip_attr.clip_norm)
elif param_attr.gradient_clip_attr.type == "GradientClipByGlobalNorm":
clip = fluid.clip.GradientClipByNorm(
clip_norm=param_attr.gradient_clip_attr.clip_norm,
group_name=param_attr.gradient_clip_attr.group_name)
else:
clip = None
param['gradient_clip_attr'] = clip
            if param['name'] not in global_block.vars:
continue
var = global_block.var(param['name'])
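The TODO above flags the string matching as fragile. One possible direction is table-driven dispatch; a minimal sketch for the regularizer case, reusing only the `fluid.regularizer` constructors already used in this hunk (the helper name is hypothetical):

```python
import paddle.fluid as fluid

# Map the type strings stored in the ParamAttr proto back to constructors.
_REGULARIZER_TYPES = {
    "L2DecayRegularizer": fluid.regularizer.L2DecayRegularizer,
    "L1DecayRegularizer": fluid.regularizer.L1DecayRegularizer,
}


def _recover_regularizer(regularizer_pb):
    # regularizer_pb is the ParamAttr.Regularizer submessage; return None
    # for an unset or unknown type, matching the else branch above.
    cls = _REGULARIZER_TYPES.get(regularizer_pb.type)
    if cls is None:
        return None
    return cls(regularization_coeff=regularizer_pb.regularization_coeff)
```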
@@ -309,20 +346,46 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
fo.write("{}\t{}\n".format(w, w_id))
# save fluid Parameter
param_arr = []
param_attrs = module_desc.param_attrs
for param in program.global_block().iter_parameters():
param_info = {
'name': param.name,
'regularizer': param.regularizer,
'gradient_clip_attr': param.gradient_clip_attr,
'trainable': param.trainable,
'optimize_attr': param.optimize_attr,
'do_model_average': param.do_model_average
}
param_arr.append(param_info)
with open(ModuleConfig.meta_param_path(module_dir), "wb") as fo:
pickle.dump(param_arr, fo)
param_attr = param_attrs[param.name]
param_attr.trainable = param.trainable
if param.do_model_average:
param_attr.do_model_average = param.do_model_average
        # TODO(wuzewu): add a func to convert a python dict to FlexibleData
param_attr.optimize_attr.type = module_desc_pb2.MAP
param_attr.optimize_attr.m['learning_rate'].type = module_desc_pb2.FLOAT
param_attr.optimize_attr.m['learning_rate'].f = param.optimize_attr[
'learning_rate']
        if param.regularizer:
            if isinstance(param.regularizer,
                          fluid.regularizer.L2DecayRegularizer):
                param_attr.regularizer.type = "L2DecayRegularizer"
            elif isinstance(param.regularizer,
                            fluid.regularizer.L1DecayRegularizer):
                param_attr.regularizer.type = "L1DecayRegularizer"
            param_attr.regularizer.regularization_coeff = param.regularizer.regularization_coeff
        if param.gradient_clip_attr:
            if isinstance(param.gradient_clip_attr,
                          fluid.clip.ErrorClipByValue):
                param_attr.gradient_clip_attr.max = param.gradient_clip_attr.max
                param_attr.gradient_clip_attr.min = param.gradient_clip_attr.min
                param_attr.gradient_clip_attr.type = "ErrorClipByValue"
            elif isinstance(param.gradient_clip_attr,
                            fluid.clip.GradientClipByValue):
                param_attr.gradient_clip_attr.max = param.gradient_clip_attr.max
                param_attr.gradient_clip_attr.min = param.gradient_clip_attr.min
                param_attr.gradient_clip_attr.type = "GradientClipByValue"
            elif isinstance(param.gradient_clip_attr,
                            fluid.clip.GradientClipByNorm):
                param_attr.gradient_clip_attr.clip_norm = param.gradient_clip_attr.clip_norm
                param_attr.gradient_clip_attr.type = "GradientClipByNorm"
            elif isinstance(param.gradient_clip_attr,
                            fluid.clip.GradientClipByGlobalNorm):
                param_attr.gradient_clip_attr.clip_norm = param.gradient_clip_attr.clip_norm
                param_attr.gradient_clip_attr.group_name = param.gradient_clip_attr.group_name
                param_attr.gradient_clip_attr.type = "GradientClipByGlobalNorm"
    # save signature info
sign_map = module_desc.sign2var
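The TODO above asks for a helper that converts a python dict into FlexibleData. A minimal sketch of such a helper, assuming `module_desc_pb2` is in scope as in the code above (the function name and structure are hypothetical):

```python
def dict_to_flexible_data(value, fd):
    # Recursively encode a python value into a FlexibleData message.
    # bool must be tested before int, since bool is a subclass of int.
    if isinstance(value, bool):
        fd.type = module_desc_pb2.BOOLEAN
        fd.b = value
    elif isinstance(value, int):
        fd.type = module_desc_pb2.INT
        fd.i = value
    elif isinstance(value, float):
        fd.type = module_desc_pb2.FLOAT
        fd.f = value
    elif isinstance(value, str):
        fd.type = module_desc_pb2.STRING
        fd.s = value
    elif isinstance(value, dict):
        fd.type = module_desc_pb2.MAP
        for key, item in value.items():
            dict_to_flexible_data(item, fd.m[key])
    elif isinstance(value, list):
        fd.type = module_desc_pb2.LIST
        for index, item in enumerate(value):
            dict_to_flexible_data(item, fd.l[index])
```

With such a helper, the manual `optimize_attr` assignments above reduce to `dict_to_flexible_data(param.optimize_attr, param_attr.optimize_attr)`.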
@@ -18,6 +18,26 @@ option optimize_for = LITE_RUNTIME;
package paddle_hub;
enum DataType {
INT = 0;
FLOAT = 1;
STRING = 2;
BOOLEAN = 3;
LIST = 4;
MAP = 5;
}
message FlexibleData {
DataType type = 1;
string name = 2;
int32 i = 3;
float f = 4;
bool b = 5;
string s = 6;
map<string, FlexibleData> m = 7;
map<int32, FlexibleData> l = 8;
}
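`FlexibleData` is a tagged union: `type` records which field carries the value, `m` holds dict-like data, and `l` encodes a list as an index-keyed map. A minimal sketch of encoding `{'learning_rate': 1e-3}`, assuming the generated bindings are importable as `module_desc_pb2` as in the python code above:

```python
fd = module_desc_pb2.FlexibleData()
fd.type = module_desc_pb2.MAP
# Indexing m['learning_rate'] creates the nested FlexibleData in place.
fd.m['learning_rate'].type = module_desc_pb2.FLOAT
fd.m['learning_rate'].f = 1e-3
```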
// Feed Variable Description
message FeedDesc {
string var_name = 1;
@@ -41,6 +61,27 @@ message AuthInfo {
string hub_version = 2;
}
message ParamAttr {
message Regularizer {
string type = 1;
float regularization_coeff = 2;
}
message GradientClipAttr {
string type = 1;
float min = 2;
float max = 3;
float clip_norm = 4;
string group_name = 5;
}
Regularizer regularizer = 1;
GradientClipAttr gradient_clip_attr = 2;
FlexibleData optimize_attr = 3;
bool trainable = 4;
bool do_model_average = 5;
}
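`ParamAttr` mirrors the fields of a fluid parameter one-to-one. A sketch of what one serialized entry in the `param_attrs` map (declared on `ModuleDesc` below) might hold; the parameter name is made up:

```python
attr = module_desc.param_attrs["fc_0.w_0"]  # hypothetical parameter name
attr.trainable = True
attr.do_model_average = False
attr.regularizer.type = "L2DecayRegularizer"
attr.regularizer.regularization_coeff = 1e-4
attr.gradient_clip_attr.type = "GradientClipByGlobalNorm"
attr.gradient_clip_attr.clip_norm = 5.0
attr.gradient_clip_attr.group_name = "default_group"
```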
// A Hub Module is stored in a directory with a file 'paddlehub.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
@@ -56,5 +97,7 @@ message ModuleDesc {
bool contain_assets = 4;
AuthInfo auth_info = 5;
map<string, ParamAttr> param_attrs = 6;
};
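Per the comment above, the whole `ModuleDesc`, now including `param_attrs`, is serialized into a single `paddlehub.pb` file. A hedged sketch of the write path using the standard protobuf API (this diff does not show where the call actually lives):

```python
import os


def save_module_desc(module_desc, module_dir):
    # SerializeToString() is the standard protobuf binary encoding.
    with open(os.path.join(module_dir, "paddlehub.pb"), "wb") as f:
        f.write(module_desc.SerializeToString())
```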