未验证 提交 e8caffbb 编写于 作者: C Chen Weihang 提交者: GitHub

fix jit.save input_spec type change problem (#25683)

* fix jit.save input type change error

* add unittes
上级 364cc536
...@@ -653,8 +653,9 @@ def save(layer, model_path, input_spec=None, configs=None): ...@@ -653,8 +653,9 @@ def save(layer, model_path, input_spec=None, configs=None):
""" """
def get_inout_spec(all_vars, target_vars, return_name=False): def get_inout_spec(all_vars, target_vars, return_name=False):
valid_vars = [var for var in all_vars if isinstance(var, Variable)] result_list = []
valid_var_dict = {} valid_var_dict = {}
valid_vars = [var for var in all_vars if isinstance(var, Variable)]
for var in valid_vars: for var in valid_vars:
valid_var_dict[var.name] = var valid_var_dict[var.name] = var
if target_vars: if target_vars:
...@@ -663,13 +664,13 @@ def save(layer, model_path, input_spec=None, configs=None): ...@@ -663,13 +664,13 @@ def save(layer, model_path, input_spec=None, configs=None):
if var.name not in valid_var_dict: if var.name not in valid_var_dict:
raise RuntimeError( raise RuntimeError(
"The variable to feed/fetch are not exist.") "The variable to feed/fetch are not exist.")
target_vars[i] = valid_var_dict[var.name] result_list.append(valid_var_dict[var.name])
else: else:
target_vars = valid_vars result_list = valid_vars
if return_name: if return_name:
target_vars = [var.name for var in target_vars] result_list = [var.name for var in target_vars]
return target_vars return result_list
# 1. input check # 1. input check
prog_translator = ProgramTranslator() prog_translator = ProgramTranslator()
......
...@@ -114,8 +114,11 @@ class TestJitSaveLoad(unittest.TestCase): ...@@ -114,8 +114,11 @@ class TestJitSaveLoad(unittest.TestCase):
def train_and_save_model(self): def train_and_save_model(self):
layer = LinearNet(784, 1) layer = LinearNet(784, 1)
example_inputs, layer, _ = train(layer) example_inputs, layer, _ = train(layer)
orig_input_types = [type(x) for x in example_inputs]
fluid.dygraph.jit.save( fluid.dygraph.jit.save(
layer=layer, model_path=self.model_path, input_spec=example_inputs) layer=layer, model_path=self.model_path, input_spec=example_inputs)
new_input_types = [type(x) for x in example_inputs]
self.assertEqual(orig_input_types, new_input_types)
return layer return layer
def test_save(self): def test_save(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册