diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
index 4b57301388ef7003586c345a9352389d427feb26..d54916a72c898076185f163614d50ec9ff4ff607 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -550,6 +550,7 @@ class ProgramTranslator(object):
         source_code = ast_to_source_code(root_wrapper.node)
         return source_code
 
+    @switch_to_static_graph
     def save_inference_model(self, dirname, feed=None, fetch=None):
         """
         Saves current model as the inference model. It will prune the main_program
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py
index 2880dd00559c33aa6dc9987d715a9cf87ac8e1cb..722b0f14fa060cb843aedf25e9b1ca9ee0ac4407 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py
@@ -200,7 +200,6 @@ class TestMNISTWithDeclarative(TestMNIST):
                     break
         return loss_data
 
-    @switch_to_static_graph
     def check_save_inference_model(self, inputs, prog_trans, to_static, gt_out):
         if to_static:
             infer_model_path = "./test_mnist_inference_model"
@@ -208,6 +207,7 @@
             infer_out = self.load_and_run_inference(infer_model_path, inputs)
             self.assertTrue(np.allclose(gt_out.numpy(), infer_out))
 
+    @switch_to_static_graph
     def load_and_run_inference(self, model_path, inputs):
         exe = fluid.Executor(self.place)
         [inference_program, feed_target_names,
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
index 7414d240bf5727886f0d8ae5b493194f0bdd25e3..180ada7b9a769731e82db239dc696e23c13feed5 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
@@ -30,6 +30,7 @@ np.random.seed(SEED)
 
 place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
 )
+program_translator = ProgramTranslator()
 
 
 class SimpleFcLayer(fluid.dygraph.Layer):
@@ -63,6 +64,10 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
             loss.backward()
             adam.minimize(loss)
             layer.clear_gradients()
+            # test for saving model in dygraph.guard
+            infer_model_dir = "./test_dy2stat_save_inference_model"
+            program_translator.save_inference_model(
+                infer_model_dir, feed=[0], fetch=[1])
             # Check the correctness of the inference
             dygraph_out, _ = layer(x)
         self.check_save_inference_model(layer, [x_data], dygraph_out.numpy())
@@ -77,7 +82,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
                                   gt_out,
                                   feed=None,
                                   fetch=None):
-        program_translator = ProgramTranslator()
+
         expected_persistable_vars = set([p.name for p in model.parameters()])
 
         infer_model_dir = "./test_dy2stat_save_inference_model"