提交 96102e2f 编写于 作者: W wuzewu

save unique name generator and reload it when creating module

上级 7e3c5dda
...@@ -27,6 +27,7 @@ import pickle ...@@ -27,6 +27,7 @@ import pickle
from collections import defaultdict from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2 from paddle_hub import module_desc_pb2
from paddle_hub.config import RunConfig, ParamTrainConfig
__all__ = ["Module", "ModuleConfig", "ModuleUtils"] __all__ = ["Module", "ModuleConfig", "ModuleUtils"]
DICT_NAME = "dict.txt" DICT_NAME = "dict.txt"
...@@ -86,6 +87,13 @@ class Module(object): ...@@ -86,6 +87,13 @@ class Module(object):
self.config = ModuleConfig(self.module_dir) self.config = ModuleConfig(self.module_dir)
self.config.load() self.config.load()
self._process_parameter() self._process_parameter()
#TODO(wuzewu): recover the default unique name generator someother where
self._process_uqn()
def _process_uqn(self):
filepath = os.path.join(self.module_dir, "uqn.pkl")
with open(filepath, "rb") as file:
fluid.unique_name.switch(pickle.load(file))
def _process_parameter(self): def _process_parameter(self):
global_block = self.inference_program.global_block() global_block = self.inference_program.global_block()
...@@ -116,27 +124,25 @@ class Module(object): ...@@ -116,27 +124,25 @@ class Module(object):
return feed_dict return feed_dict
def __call__(self, inputs=None, sign_name="default"): def __call__(self, sign_name="default", run_config=None):
""" Call default signature and return results """ Call default signature and return results
""" """
# word_ids_lod_tensor = self._preprocess_input(inputs)
feed_dict = self._construct_feed_dict(inputs) def _set_param_trainable(program, trainable=False):
print("feed_dict", feed_dict) for param in program.global_block().iter_parameters():
param.trainable = trainable
ret_numpy = self.config.return_numpy()
print("ret_numpy", ret_numpy) if not run_config:
results = self.exe.run( run_config = RunConfig()
self.inference_program,
#feed={self.feed_target_names[0]: word_ids_lod_tensor}, program = self.get_inference_program().clone()
feed=feed_dict,
fetch_list=self.fetch_targets, if run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
return_numpy=ret_numpy) _set_param_trainable(program=program, trainable=True)
elif run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
print("module fetch_target_names", self.feed_target_names) _set_param_trainable(program=program, trainable=False)
print("module fetch_targets", self.fetch_targets)
np_result = np.array(results[0]) return self.feed_target_names, self.fetch_targets, program
return np_result
def get_vars(self): def get_vars(self):
""" """
......
...@@ -46,6 +46,12 @@ def create_module(sign_arr, program, path=None, assets=None): ...@@ -46,6 +46,12 @@ def create_module(sign_arr, program, path=None, assets=None):
module.contain_assets = True module.contain_assets = True
os.makedirs(os.path.join(path, "assets")) os.makedirs(os.path.join(path, "assets"))
# save the unique name object
generator = fluid.unique_name.generator
pklname = os.path.join(path, "uqn.pkl")
with open(pklname, "wb") as file:
pickle.dump(generator, file)
# save fluid Parameter # save fluid Parameter
param_arr = [] param_arr = []
for param in program.global_block().iter_parameters(): for param in program.global_block().iter_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册