diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py index 52c5fcd4e4aebbbe8c77cea030d7fb81de93938e..c7acf7f60a88a3b07c784079b7b631c87e546fad 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py @@ -53,5 +53,18 @@ class TestFallback(unittest.TestCase): np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) +class TestLoad2(unittest.TestCase): + def test_name_load_nograd(self): + @paddle.no_grad() + def func(x): + x = paddle.shape(x) + return x + + x = paddle.to_tensor([[3, 3], [1, 1]]) + output_st = paddle.jit.to_static(func)(x) + output_dy = func(x) + np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/jit/dy2static/ast_transformer.py b/python/paddle/jit/dy2static/ast_transformer.py index 3c7926d8fa621e1fb624614056e94ee59768d11e..0e935c5842433f0d14314db2672b7acdf64b5a58 100644 --- a/python/paddle/jit/dy2static/ast_transformer.py +++ b/python/paddle/jit/dy2static/ast_transformer.py @@ -93,7 +93,6 @@ class DygraphToStaticAst(BaseTransformer): transformers = [ EarlyReturnTransformer, - DecoratorTransformer, # transform decorators to function call BasicApiTransformer, # Basic Api TensorShapeTransformer, # Tensor.shape -> paddle.shape(Tensor) BreakContinueTransformer, # break/continue in loops @@ -105,6 +104,7 @@ class DygraphToStaticAst(BaseTransformer): AssertTransformer, # assert statement CallTransformer, # transform call recursively CastTransformer, # type casting statement + DecoratorTransformer, # transform decorators to function call NameloadJstTransformer, TypeHintTransformer, # remove all typehint in gast.Name ]