未验证 提交 82f170b6 编写于 作者: X xiongkun 提交者: GitHub

[Dy2static] bugfix load transformer (#50776)

上级 f43b5fe5
......@@ -53,5 +53,18 @@ class TestFallback(unittest.TestCase):
np.testing.assert_allclose(output_dy.numpy(), output_st.numpy())
class TestLoad2(unittest.TestCase):
def test_name_load_nograd(self):
@paddle.no_grad()
def func(x):
x = paddle.shape(x)
return x
x = paddle.to_tensor([[3, 3], [1, 1]])
output_st = paddle.jit.to_static(func)(x)
output_dy = func(x)
np.testing.assert_allclose(output_dy.numpy(), output_st.numpy())
if __name__ == "__main__":
unittest.main()
......@@ -93,7 +93,6 @@ class DygraphToStaticAst(BaseTransformer):
transformers = [
EarlyReturnTransformer,
DecoratorTransformer, # transform decorators to function call
BasicApiTransformer, # Basic Api
TensorShapeTransformer, # Tensor.shape -> paddle.shape(Tensor)
BreakContinueTransformer, # break/continue in loops
......@@ -105,6 +104,7 @@ class DygraphToStaticAst(BaseTransformer):
AssertTransformer, # assert statement
CallTransformer, # transform call recursively
CastTransformer, # type casting statement
DecoratorTransformer, # transform decorators to function call
NameloadJstTransformer,
TypeHintTransformer, # remove all typehint in gast.Name
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册