提交 96102e2f 编写于 作者: W wuzewu

save unique name generator and reload it when creating module

上级 7e3c5dda
......@@ -27,6 +27,7 @@ import pickle
from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2
from paddle_hub.config import RunConfig, ParamTrainConfig
__all__ = ["Module", "ModuleConfig", "ModuleUtils"]
DICT_NAME = "dict.txt"
......@@ -86,6 +87,13 @@ class Module(object):
self.config = ModuleConfig(self.module_dir)
self.config.load()
self._process_parameter()
#TODO(wuzewu): recover the default unique name generator someother where
self._process_uqn()
def _process_uqn(self):
filepath = os.path.join(self.module_dir, "uqn.pkl")
with open(filepath, "rb") as file:
fluid.unique_name.switch(pickle.load(file))
def _process_parameter(self):
global_block = self.inference_program.global_block()
......@@ -116,27 +124,25 @@ class Module(object):
return feed_dict
def __call__(self, inputs=None, sign_name="default"):
def __call__(self, sign_name="default", run_config=None):
""" Call default signature and return results
"""
# word_ids_lod_tensor = self._preprocess_input(inputs)
feed_dict = self._construct_feed_dict(inputs)
print("feed_dict", feed_dict)
ret_numpy = self.config.return_numpy()
print("ret_numpy", ret_numpy)
results = self.exe.run(
self.inference_program,
#feed={self.feed_target_names[0]: word_ids_lod_tensor},
feed=feed_dict,
fetch_list=self.fetch_targets,
return_numpy=ret_numpy)
print("module fetch_target_names", self.feed_target_names)
print("module fetch_targets", self.fetch_targets)
np_result = np.array(results[0])
return np_result
def _set_param_trainable(program, trainable=False):
for param in program.global_block().iter_parameters():
param.trainable = trainable
if not run_config:
run_config = RunConfig()
program = self.get_inference_program().clone()
if run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
_set_param_trainable(program=program, trainable=True)
elif run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
_set_param_trainable(program=program, trainable=False)
return self.feed_target_names, self.fetch_targets, program
def get_vars(self):
"""
......
......@@ -46,6 +46,12 @@ def create_module(sign_arr, program, path=None, assets=None):
module.contain_assets = True
os.makedirs(os.path.join(path, "assets"))
# save the unique name object
generator = fluid.unique_name.generator
pklname = os.path.join(path, "uqn.pkl")
with open(pklname, "wb") as file:
pickle.dump(generator, file)
# save fluid Parameter
param_arr = []
for param in program.global_block().iter_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册