未验证 提交 d9c70217 编写于 作者: W WeiXin 提交者: GitHub

[cherry pick] Fix bug for 'save mutiple method' (#30218) (#30278)

* Fix bug for 'save mutiple method'

* To pass coverage.

* edit code to pass coverage.

* edit code to pass coverage.

* add unittest for coverage.

* change for coverage.

* edit for coverage.
上级 7c943a65
...@@ -600,11 +600,15 @@ def _construct_program_holders(model_path, model_filename=None): ...@@ -600,11 +600,15 @@ def _construct_program_holders(model_path, model_filename=None):
model_file_path = os.path.join(model_path, model_filename) model_file_path = os.path.join(model_path, model_filename)
elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith( elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith(
model_name): model_name):
func_name = filename[len(model_name) + 1:-len( parsing_names = filename[len(model_name):-len(
INFER_MODEL_SUFFIX)] INFER_MODEL_SUFFIX) + 1].split('.')
if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
func_name = parsing_names[1]
model_file_path = os.path.join(model_path, filename) model_file_path = os.path.join(model_path, filename)
else: else:
continue continue
else:
continue
program_holder_dict[func_name] = _ProgramHolder( program_holder_dict[func_name] = _ProgramHolder(
_load_program_desc(model_file_path)) _load_program_desc(model_file_path))
else: else:
...@@ -636,10 +640,14 @@ def _construct_params_and_buffers(model_path, ...@@ -636,10 +640,14 @@ def _construct_params_and_buffers(model_path,
model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)] model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)]
#Load every file that meets the requirements in the directory model_path. #Load every file that meets the requirements in the directory model_path.
for file_name in os.listdir(model_path): for file_name in os.listdir(model_path):
if file_name.endswith(INFER_PARAMS_SUFFIX) and file_name.startswith( if file_name.startswith(model_name) and file_name.endswith(
model_name) and file_name != params_filename: INFER_PARAMS_SUFFIX):
func_name = file_name[len(model_name) + 1:-len( parsing_names = file_name[len(model_name):-len(
INFER_PARAMS_SUFFIX)] INFER_PARAMS_SUFFIX) + 1].split('.')
if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
func_name = parsing_names[1]
else:
continue
else: else:
continue continue
var_info_path = os.path.join(model_path, var_info_filename) var_info_path = os.path.join(model_path, var_info_filename)
......
...@@ -864,6 +864,18 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase): ...@@ -864,6 +864,18 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase):
paddle.jit.save( paddle.jit.save(
layer, model_path, input_spec=[InputSpec(shape=[None, 784])]) layer, model_path, input_spec=[InputSpec(shape=[None, 784])])
def test_parse_name(self):
model_path_inference = "jit_save_load_parse_name/model"
IMAGE_SIZE = 224
layer = LinearNet(IMAGE_SIZE, 1)
inps = paddle.randn([1, IMAGE_SIZE])
layer(inps)
paddle.jit.save(layer, model_path_inference)
paddle.jit.save(layer, model_path_inference + '_v2')
load_net = paddle.jit.load(model_path_inference)
self.assertFalse(hasattr(load_net, 'v2'))
class LayerSaved(paddle.nn.Layer): class LayerSaved(paddle.nn.Layer):
def __init__(self, in_size, out_size): def __init__(self, in_size, out_size):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册