提交 86470a5d 编写于 作者: W wuzewu

solve the problem of duplicate name by rename program var when create module

上级 8ce73b13
......@@ -43,7 +43,8 @@ MODEL_DIRNAME = "model"
DICT_FILENAME = "vocab.txt"
PARAM_FILENAME = "param.pkl"
MODULE_DESC_PBNAME = "module_desc.pb"
GENERATOR_FILENAME = "unique_name_generator.pkl"
# paddle hub var prefix
HUB_VAR_PREFIX = "@HUB@"
def mkdir(path):
......@@ -83,6 +84,7 @@ class Module(object):
with open(param_path, "rb") as file:
param_arr = pickle.load(file)
for param in param_arr:
param['name'] = HUB_VAR_PREFIX + param['name']
if (param['name'] not in global_block.vars):
continue
var = global_block.var(param['name'])
......@@ -146,9 +148,6 @@ class Module(object):
print("**feed_target_names**\n{}".format(self.feed_target_names))
print("**fetch_targets**\n{}".format(self.fetch_targets))
self._process_parameter()
name_generator_path = ModuleConfig.name_generator_path(self.module_dir)
with open(name_generator_path, "rb") as data:
generator = pickle.load(data)
program = self.get_inference_program().clone()
......@@ -156,14 +155,14 @@ class Module(object):
_set_param_trainable(program=program, trainable=trainable)
for key, value in feed_dict.items():
var = program.global_block().var(value)
var = program.global_block().var(HUB_VAR_PREFIX + value)
feed_dict[key] = var
for key, value in fetch_dict.items():
var = program.global_block().var(value)
var = program.global_block().var(HUB_VAR_PREFIX + value)
fetch_dict[key] = var
return feed_dict, fetch_dict, program, generator
return feed_dict, fetch_dict, program
def get_inference_program(self):
return self.inference_program
......@@ -253,12 +252,6 @@ class ModuleConfig(object):
def module_desc_path(module_dir):
return os.path.join(module_dir, MODULE_DESC_PBNAME)
@staticmethod
def name_generator_path(module_dir):
meta_path = os.path.join(module_dir, META_DIRNAME)
mkdir(meta_path)
return os.path.join(meta_path, GENERATOR_FILENAME)
@staticmethod
def assets_dict_path(module_dir):
assets_path = os.path.join(module_dir, ASSETS_DIRNAME)
......@@ -271,12 +264,6 @@ class ModuleConfig(object):
mkdir(meta_path)
return os.path.join(meta_path, PARAM_FILENAME)
@staticmethod
def meta_name_generator_path(module_dir):
meta_path = os.path.join(module_dir, META_DIRNAME)
mkdir(meta_path)
return os.path.join(meta_path, GENERATOR_FILENAME)
def create_module(sign_arr, module_dir=None, word_dict=None):
""" Create a module from main program
......@@ -321,19 +308,6 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
w_id = word_dict[w]
fo.write("{}\t{}\n".format(w, w_id))
# save the unique name generator object
var_name_arr = [
'_'.join(var.split('@')[0].split('.')[0].split('_')[0:-1])
for block in program.blocks for var in block.vars
]
with fluid.unique_name.guard():
for var_name in var_name_arr:
fluid.unique_name.generate(var_name)
generator = fluid.unique_name.generator
with open(ModuleConfig.name_generator_path(module_dir), "wb") as fo:
pickle.dump(generator, fo)
# save fluid Parameter
param_arr = []
for param in program.global_block().iter_parameters():
......@@ -386,6 +360,30 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
main_program=program,
executor=exe)
with open(os.path.join(save_model_dir, "__model__"), "rb") as file:
program_desc_str = file.read()
rename_program = fluid.framework.Program.parse_from_string(
program_desc_str)
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if HUB_VAR_PREFIX not in var
}
for var, block in varlist.items():
old_name = var
new_name = HUB_VAR_PREFIX + old_name
block._rename_var(old_name, new_name)
mkdir(save_model_dir)
with open(os.path.join(save_model_dir, "__model__"), "wb") as f:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(save_model_dir):
if (file == "__model__" or HUB_VAR_PREFIX in file):
continue
os.rename(
os.path.join(save_model_dir, file),
os.path.join(save_model_dir, HUB_VAR_PREFIX + file))
# Serialize module_desc pb
module_pb = module_desc.SerializeToString()
with open(ModuleConfig.module_desc_path(module_dir), "wb") as f:
......@@ -410,7 +408,8 @@ class ModuleUtils(object):
for index in need_to_remove_op_index[::-1]:
block._remove_op(index)
block._remove_var("feed")
block._remove_var("fetch")
# TODO(wuzewu): get feed and fetch var by other way
block._remove_var(HUB_VAR_PREFIX + "feed")
block._remove_var(HUB_VAR_PREFIX + "fetch")
program.desc.flush()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册