From e8caffbb4a74a58cc3d474995ba4f5b13465831a Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 24 Jul 2020 15:12:49 +0800 Subject: [PATCH] fix jit.save input_spec type change problem (#25683) * fix jit.save input type change error * add unittes --- python/paddle/fluid/dygraph/jit.py | 11 ++++++----- .../fluid/tests/unittests/test_jit_save_load.py | 3 +++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index bd468b55d8..754a0b67fe 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -653,8 +653,9 @@ def save(layer, model_path, input_spec=None, configs=None): """ def get_inout_spec(all_vars, target_vars, return_name=False): - valid_vars = [var for var in all_vars if isinstance(var, Variable)] + result_list = [] valid_var_dict = {} + valid_vars = [var for var in all_vars if isinstance(var, Variable)] for var in valid_vars: valid_var_dict[var.name] = var if target_vars: @@ -663,13 +664,13 @@ def save(layer, model_path, input_spec=None, configs=None): if var.name not in valid_var_dict: raise RuntimeError( "The variable to feed/fetch are not exist.") - target_vars[i] = valid_var_dict[var.name] + result_list.append(valid_var_dict[var.name]) else: - target_vars = valid_vars + result_list = valid_vars if return_name: - target_vars = [var.name for var in target_vars] + result_list = [var.name for var in target_vars] - return target_vars + return result_list # 1. input check prog_translator = ProgramTranslator() diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index 640e966354..abc4603495 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -114,8 +114,11 @@ class TestJitSaveLoad(unittest.TestCase): def train_and_save_model(self): layer = LinearNet(784, 1) example_inputs, layer, _ = train(layer) + orig_input_types = [type(x) for x in example_inputs] fluid.dygraph.jit.save( layer=layer, model_path=self.model_path, input_spec=example_inputs) + new_input_types = [type(x) for x in example_inputs] + self.assertEqual(orig_input_types, new_input_types) return layer def test_save(self): -- GitLab