From 82f170b685804406289000830e3b7849cc48fbb7 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 23 Feb 2023 12:44:30 +0800 Subject: [PATCH] [Dy2static] bugfix load transformer (#50776) --- .../dygraph_to_static/test_load_transformer.py | 13 +++++++++++++ python/paddle/jit/dy2static/ast_transformer.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py index 52c5fcd4e4a..c7acf7f60a8 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py @@ -53,5 +53,18 @@ class TestFallback(unittest.TestCase): np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) +class TestLoad2(unittest.TestCase): + def test_name_load_nograd(self): + @paddle.no_grad() + def func(x): + x = paddle.shape(x) + return x + + x = paddle.to_tensor([[3, 3], [1, 1]]) + output_st = paddle.jit.to_static(func)(x) + output_dy = func(x) + np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/jit/dy2static/ast_transformer.py b/python/paddle/jit/dy2static/ast_transformer.py index 3c7926d8fa6..0e935c58424 100644 --- a/python/paddle/jit/dy2static/ast_transformer.py +++ b/python/paddle/jit/dy2static/ast_transformer.py @@ -93,7 +93,6 @@ class DygraphToStaticAst(BaseTransformer): transformers = [ EarlyReturnTransformer, - DecoratorTransformer, # transform decorators to function call BasicApiTransformer, # Basic Api TensorShapeTransformer, # Tensor.shape -> paddle.shape(Tensor) BreakContinueTransformer, # break/continue in loops @@ -105,6 +104,7 @@ class DygraphToStaticAst(BaseTransformer): AssertTransformer, # assert statement CallTransformer, # transform call recursively CastTransformer, # type casting statement + DecoratorTransformer, # transform decorators to function call NameloadJstTransformer, TypeHintTransformer, # remove all typehint in gast.Name ] -- GitLab