提交 a96ef721 编写于 作者: W wuzewu

save more param info

上级 df4f8b00
......@@ -22,6 +22,7 @@ import paddle.fluid as fluid
import numpy as np
import tempfile
import os
import pickle
from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress
......@@ -88,21 +89,22 @@ class Module(object):
def _process_parameter(self):
global_block = self.inference_program.global_block()
for param in self.config.desc.parameters:
name = param.name
if name in global_block.vars:
var = global_block.vars[name]
global_block.create_parameter(
name=name,
trainable=param.trainable,
shape=var.shape,
dtype=var.dtype,
optimize_attr={'learning_rate': param.learning_rate},
type=var.type,
lod_level=var.lod_level,
error_clip=var.error_clip,
stop_gradient=var.stop_gradient,
is_data=var.is_data)
filepath = os.path.join(self.module_dir, "param.pkl")
with open(filepath, "rb") as file:
param_arr = pickle.load(file)
for param in param_arr:
if (param['name'] not in global_block.vars):
continue
var = global_block.var(param['name'])
global_block.create_parameter(
**param,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
error_clip=var.error_clip,
stop_gradient=var.stop_gradient,
is_data=var.is_data)
def _construct_feed_dict(self, inputs):
""" Construct feed dict according to user's inputs and module config.
......
......@@ -23,6 +23,7 @@ from paddle_hub.signature import Signature
from paddle_hub.module import mkdir
import os
import pickle
def create_module(sign_arr, program, path=None, assets=None):
......@@ -46,11 +47,21 @@ def create_module(sign_arr, program, path=None, assets=None):
os.makedirs(os.path.join(path, "assets"))
# save fluid Parameter
param_arr = []
for param in program.global_block().iter_parameters():
parameter = module.parameters.add()
parameter.name = param.name
parameter.learning_rate = param.optimize_attr["learning_rate"]
parameter.trainable = param.trainable
param_info = {
'name': param.name,
'regularizer': param.regularizer,
'gradient_clip_attr': param.gradient_clip_attr,
'trainable': param.trainable,
'optimize_attr': param.optimize_attr,
'do_model_average': param.do_model_average
}
param_arr.append(param_info)
pklname = os.path.join(path, "param.pkl")
with open(pklname, "wb") as file:
pickle.dump(param_arr, file)
# save signarture info
sign_map = module.sign2var
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册