提交 a96ef721 编写于 作者: W wuzewu

save more param info

上级 df4f8b00
...@@ -22,6 +22,7 @@ import paddle.fluid as fluid ...@@ -22,6 +22,7 @@ import paddle.fluid as fluid
import numpy as np import numpy as np
import tempfile import tempfile
import os import os
import pickle
from collections import defaultdict from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress from paddle_hub.downloader import download_and_uncompress
...@@ -88,21 +89,22 @@ class Module(object): ...@@ -88,21 +89,22 @@ class Module(object):
def _process_parameter(self): def _process_parameter(self):
global_block = self.inference_program.global_block() global_block = self.inference_program.global_block()
for param in self.config.desc.parameters: filepath = os.path.join(self.module_dir, "param.pkl")
name = param.name with open(filepath, "rb") as file:
if name in global_block.vars: param_arr = pickle.load(file)
var = global_block.vars[name] for param in param_arr:
global_block.create_parameter( if (param['name'] not in global_block.vars):
name=name, continue
trainable=param.trainable, var = global_block.var(param['name'])
shape=var.shape, global_block.create_parameter(
dtype=var.dtype, **param,
optimize_attr={'learning_rate': param.learning_rate}, shape=var.shape,
type=var.type, dtype=var.dtype,
lod_level=var.lod_level, type=var.type,
error_clip=var.error_clip, lod_level=var.lod_level,
stop_gradient=var.stop_gradient, error_clip=var.error_clip,
is_data=var.is_data) stop_gradient=var.stop_gradient,
is_data=var.is_data)
def _construct_feed_dict(self, inputs): def _construct_feed_dict(self, inputs):
""" Construct feed dict according to user's inputs and module config. """ Construct feed dict according to user's inputs and module config.
......
...@@ -23,6 +23,7 @@ from paddle_hub.signature import Signature ...@@ -23,6 +23,7 @@ from paddle_hub.signature import Signature
from paddle_hub.module import mkdir from paddle_hub.module import mkdir
import os import os
import pickle
def create_module(sign_arr, program, path=None, assets=None): def create_module(sign_arr, program, path=None, assets=None):
...@@ -46,11 +47,21 @@ def create_module(sign_arr, program, path=None, assets=None): ...@@ -46,11 +47,21 @@ def create_module(sign_arr, program, path=None, assets=None):
os.makedirs(os.path.join(path, "assets")) os.makedirs(os.path.join(path, "assets"))
# save fluid Parameter # save fluid Parameter
param_arr = []
for param in program.global_block().iter_parameters(): for param in program.global_block().iter_parameters():
parameter = module.parameters.add() param_info = {
parameter.name = param.name 'name': param.name,
parameter.learning_rate = param.optimize_attr["learning_rate"] 'regularizer': param.regularizer,
parameter.trainable = param.trainable 'gradient_clip_attr': param.gradient_clip_attr,
'trainable': param.trainable,
'optimize_attr': param.optimize_attr,
'do_model_average': param.do_model_average
}
param_arr.append(param_info)
pklname = os.path.join(path, "param.pkl")
with open(pklname, "wb") as file:
pickle.dump(param_arr, file)
# save signarture info # save signarture info
sign_map = module.sign2var sign_map = module.sign2var
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册