diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index a2c48921deebcb6a23f2fee9177bf50924922c29..af4ba16ee8f64cb6293dd11413492e013c32b99d 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -600,9 +600,13 @@ def _construct_program_holders(model_path, model_filename=None): model_file_path = os.path.join(model_path, model_filename) elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith( model_name): - func_name = filename[len(model_name) + 1:-len( - INFER_MODEL_SUFFIX)] - model_file_path = os.path.join(model_path, filename) + parsing_names = filename[len(model_name):-len( + INFER_MODEL_SUFFIX) + 1].split('.') + if len(parsing_names) == 3 and len(parsing_names[1]) > 0: + func_name = parsing_names[1] + model_file_path = os.path.join(model_path, filename) + else: + continue else: continue program_holder_dict[func_name] = _ProgramHolder( @@ -636,10 +640,14 @@ def _construct_params_and_buffers(model_path, model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)] #Load every file that meets the requirements in the directory model_path. for file_name in os.listdir(model_path): - if file_name.endswith(INFER_PARAMS_SUFFIX) and file_name.startswith( - model_name) and file_name != params_filename: - func_name = file_name[len(model_name) + 1:-len( - INFER_PARAMS_SUFFIX)] + if file_name.startswith(model_name) and file_name.endswith( + INFER_PARAMS_SUFFIX): + parsing_names = file_name[len(model_name):-len( + INFER_PARAMS_SUFFIX) + 1].split('.') + if len(parsing_names) == 3 and len(parsing_names[1]) > 0: + func_name = parsing_names[1] + else: + continue else: continue var_info_path = os.path.join(model_path, var_info_filename) diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index dead4a19a61dad29b3896beba53901f060196b68..b2704085fd42cb4f4da3a7d07ba8e222db8ad663 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -864,6 +864,18 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase): paddle.jit.save( layer, model_path, input_spec=[InputSpec(shape=[None, 784])]) + def test_parse_name(self): + model_path_inference = "jit_save_load_parse_name/model" + IMAGE_SIZE = 224 + layer = LinearNet(IMAGE_SIZE, 1) + inps = paddle.randn([1, IMAGE_SIZE]) + layer(inps) + paddle.jit.save(layer, model_path_inference) + paddle.jit.save(layer, model_path_inference + '_v2') + load_net = paddle.jit.load(model_path_inference) + + self.assertFalse(hasattr(load_net, 'v2')) + class LayerSaved(paddle.nn.Layer): def __init__(self, in_size, out_size):