未验证 提交 e8caffbb 编写于 作者: C Chen Weihang 提交者: GitHub

fix jit.save input_spec type change problem (#25683)

* fix jit.save input type change error

* add unittes
上级 364cc536
......@@ -653,8 +653,9 @@ def save(layer, model_path, input_spec=None, configs=None):
"""
def get_inout_spec(all_vars, target_vars, return_name=False):
valid_vars = [var for var in all_vars if isinstance(var, Variable)]
result_list = []
valid_var_dict = {}
valid_vars = [var for var in all_vars if isinstance(var, Variable)]
for var in valid_vars:
valid_var_dict[var.name] = var
if target_vars:
......@@ -663,13 +664,13 @@ def save(layer, model_path, input_spec=None, configs=None):
if var.name not in valid_var_dict:
raise RuntimeError(
"The variable to feed/fetch are not exist.")
target_vars[i] = valid_var_dict[var.name]
result_list.append(valid_var_dict[var.name])
else:
target_vars = valid_vars
result_list = valid_vars
if return_name:
target_vars = [var.name for var in target_vars]
result_list = [var.name for var in target_vars]
return target_vars
return result_list
# 1. input check
prog_translator = ProgramTranslator()
......
......@@ -114,8 +114,11 @@ class TestJitSaveLoad(unittest.TestCase):
def train_and_save_model(self):
layer = LinearNet(784, 1)
example_inputs, layer, _ = train(layer)
orig_input_types = [type(x) for x in example_inputs]
fluid.dygraph.jit.save(
layer=layer, model_path=self.model_path, input_spec=example_inputs)
new_input_types = [type(x) for x in example_inputs]
self.assertEqual(orig_input_types, new_input_types)
return layer
def test_save(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册