未验证 提交 edafb546 编写于 作者: W WeiXin 提交者: GitHub

Fix bug for 'save mutiple method' (#30218)

* Fix bug for 'save mutiple method'

* To pass coverage.

* edit code to pass coverage.

* edit code to pass coverage.

* add unittest for coverage.

* change for coverage.

* edit for coverage.
上级 66dc4ac7
......@@ -600,9 +600,13 @@ def _construct_program_holders(model_path, model_filename=None):
model_file_path = os.path.join(model_path, model_filename)
elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith(
model_name):
func_name = filename[len(model_name) + 1:-len(
INFER_MODEL_SUFFIX)]
model_file_path = os.path.join(model_path, filename)
parsing_names = filename[len(model_name):-len(
INFER_MODEL_SUFFIX) + 1].split('.')
if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
func_name = parsing_names[1]
model_file_path = os.path.join(model_path, filename)
else:
continue
else:
continue
program_holder_dict[func_name] = _ProgramHolder(
......@@ -636,10 +640,14 @@ def _construct_params_and_buffers(model_path,
model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)]
#Load every file that meets the requirements in the directory model_path.
for file_name in os.listdir(model_path):
if file_name.endswith(INFER_PARAMS_SUFFIX) and file_name.startswith(
model_name) and file_name != params_filename:
func_name = file_name[len(model_name) + 1:-len(
INFER_PARAMS_SUFFIX)]
if file_name.startswith(model_name) and file_name.endswith(
INFER_PARAMS_SUFFIX):
parsing_names = file_name[len(model_name):-len(
INFER_PARAMS_SUFFIX) + 1].split('.')
if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
func_name = parsing_names[1]
else:
continue
else:
continue
var_info_path = os.path.join(model_path, var_info_filename)
......
......@@ -864,6 +864,18 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase):
paddle.jit.save(
layer, model_path, input_spec=[InputSpec(shape=[None, 784])])
def test_parse_name(self):
model_path_inference = "jit_save_load_parse_name/model"
IMAGE_SIZE = 224
layer = LinearNet(IMAGE_SIZE, 1)
inps = paddle.randn([1, IMAGE_SIZE])
layer(inps)
paddle.jit.save(layer, model_path_inference)
paddle.jit.save(layer, model_path_inference + '_v2')
load_net = paddle.jit.load(model_path_inference)
self.assertFalse(hasattr(load_net, 'v2'))
class LayerSaved(paddle.nn.Layer):
def __init__(self, in_size, out_size):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册