From 420944e5142a92a7648c19e4dff2b7af13f8d724 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Tue, 31 Mar 2020 17:04:54 +0800 Subject: [PATCH] enhance load dygraph; test=develop (#23167) --- python/paddle/fluid/dygraph/checkpoint.py | 10 ++++++++-- .../fluid/tests/unittests/test_imperative_save_load.py | 6 ++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index a51c431c51..870404e64b 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -125,7 +125,13 @@ def load_dygraph(model_path, keep_name_table=False): ''' - params_file_path = model_path + ".pdparams" + model_prefix = model_path + if model_prefix.endswith(".pdparams"): + model_prefix = model_prefix[:-9] + elif model_prefix.endswith(".pdopt"): + model_prefix = model_prefix[:-6] + + params_file_path = model_prefix + ".pdparams" if not os.path.exists(params_file_path): raise RuntimeError("Parameter file [ {} ] not exists".format( params_file_path)) @@ -137,7 +143,7 @@ def load_dygraph(model_path, keep_name_table=False): if not keep_name_table and "StructuredToParameterName@@" in para_dict: del para_dict["StructuredToParameterName@@"] opti_dict = None - opti_file_path = model_path + ".pdopt" + opti_file_path = model_prefix + ".pdopt" if os.path.exists(opti_file_path): with open(opti_file_path, 'rb') as f: opti_dict = pickle.load(f) if six.PY2 else pickle.load( diff --git a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py index 6a621b8c75..694ad077c0 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py @@ -887,6 +887,12 @@ class TestDygraphPtbRnn(unittest.TestCase): self.assertTrue(opti_state_dict == None) + para_state_dict, opti_state_dict = fluid.load_dygraph( + os.path.join('saved_dy', 'emb_dy.pdparams')) + + para_state_dict, opti_state_dict = fluid.load_dygraph( + os.path.join('saved_dy', 'emb_dy.pdopt')) + if __name__ == '__main__': unittest.main() -- GitLab