From d9c702177dda9f3ee2c1f5e3be823ded628fbd13 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Mon, 11 Jan 2021 18:56:58 +0800 Subject: [PATCH] [cherry pick] Fix bug for 'save mutiple method' (#30218) (#30278) * Fix bug for 'save mutiple method' * To pass coverage. * edit code to pass coverage. * edit code to pass coverage. * add unittest for coverage. * change for coverage. * edit for coverage. --- python/paddle/fluid/dygraph/io.py | 22 +++++++++++++------ .../tests/unittests/test_jit_save_load.py | 12 ++++++++++ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index a2c48921de..af4ba16ee8 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -600,9 +600,13 @@ def _construct_program_holders(model_path, model_filename=None): model_file_path = os.path.join(model_path, model_filename) elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith( model_name): - func_name = filename[len(model_name) + 1:-len( - INFER_MODEL_SUFFIX)] - model_file_path = os.path.join(model_path, filename) + parsing_names = filename[len(model_name):-len( + INFER_MODEL_SUFFIX) + 1].split('.') + if len(parsing_names) == 3 and len(parsing_names[1]) > 0: + func_name = parsing_names[1] + model_file_path = os.path.join(model_path, filename) + else: + continue else: continue program_holder_dict[func_name] = _ProgramHolder( @@ -636,10 +640,14 @@ def _construct_params_and_buffers(model_path, model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)] #Load every file that meets the requirements in the directory model_path. for file_name in os.listdir(model_path): - if file_name.endswith(INFER_PARAMS_SUFFIX) and file_name.startswith( - model_name) and file_name != params_filename: - func_name = file_name[len(model_name) + 1:-len( - INFER_PARAMS_SUFFIX)] + if file_name.startswith(model_name) and file_name.endswith( + INFER_PARAMS_SUFFIX): + parsing_names = file_name[len(model_name):-len( + INFER_PARAMS_SUFFIX) + 1].split('.') + if len(parsing_names) == 3 and len(parsing_names[1]) > 0: + func_name = parsing_names[1] + else: + continue else: continue var_info_path = os.path.join(model_path, var_info_filename) diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index dead4a19a6..b2704085fd 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -864,6 +864,18 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase): paddle.jit.save( layer, model_path, input_spec=[InputSpec(shape=[None, 784])]) + def test_parse_name(self): + model_path_inference = "jit_save_load_parse_name/model" + IMAGE_SIZE = 224 + layer = LinearNet(IMAGE_SIZE, 1) + inps = paddle.randn([1, IMAGE_SIZE]) + layer(inps) + paddle.jit.save(layer, model_path_inference) + paddle.jit.save(layer, model_path_inference + '_v2') + load_net = paddle.jit.load(model_path_inference) + + self.assertFalse(hasattr(load_net, 'v2')) + class LayerSaved(paddle.nn.Layer): def __init__(self, in_size, out_size): -- GitLab