未验证 提交 420944e5 编写于 作者: H hong 提交者: GitHub

enhance load dygraph; test=develop (#23167)

上级 1ee2a9a4
...@@ -125,7 +125,13 @@ def load_dygraph(model_path, keep_name_table=False): ...@@ -125,7 +125,13 @@ def load_dygraph(model_path, keep_name_table=False):
''' '''
params_file_path = model_path + ".pdparams" model_prefix = model_path
if model_prefix.endswith(".pdparams"):
model_prefix = model_prefix[:-9]
elif model_prefix.endswith(".pdopt"):
model_prefix = model_prefix[:-6]
params_file_path = model_prefix + ".pdparams"
if not os.path.exists(params_file_path): if not os.path.exists(params_file_path):
raise RuntimeError("Parameter file [ {} ] not exists".format( raise RuntimeError("Parameter file [ {} ] not exists".format(
params_file_path)) params_file_path))
...@@ -137,7 +143,7 @@ def load_dygraph(model_path, keep_name_table=False): ...@@ -137,7 +143,7 @@ def load_dygraph(model_path, keep_name_table=False):
if not keep_name_table and "StructuredToParameterName@@" in para_dict: if not keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"] del para_dict["StructuredToParameterName@@"]
opti_dict = None opti_dict = None
opti_file_path = model_path + ".pdopt" opti_file_path = model_prefix + ".pdopt"
if os.path.exists(opti_file_path): if os.path.exists(opti_file_path):
with open(opti_file_path, 'rb') as f: with open(opti_file_path, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load( opti_dict = pickle.load(f) if six.PY2 else pickle.load(
......
...@@ -887,6 +887,12 @@ class TestDygraphPtbRnn(unittest.TestCase): ...@@ -887,6 +887,12 @@ class TestDygraphPtbRnn(unittest.TestCase):
self.assertTrue(opti_state_dict == None) self.assertTrue(opti_state_dict == None)
para_state_dict, opti_state_dict = fluid.load_dygraph(
os.path.join('saved_dy', 'emb_dy.pdparams'))
para_state_dict, opti_state_dict = fluid.load_dygraph(
os.path.join('saved_dy', 'emb_dy.pdopt'))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册