未验证 提交 50822491 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2stat]Fix error in tensor_shape_transformer. (#37999)

* fix error when tensor_shape_transformer. Before in stmt like `if len(paddle.shape(x)[0]) > 0`, `paddle` will be used as a variable

* handle other call like `fluid.layers.mean` and `fluid.layers.shape`

* add unit test
上级 141b2854
......@@ -282,6 +282,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False
if isinstance(node, gast.Attribute):
# If node is `paddle.shape`, return False
if (node.attr == 'shape' and isinstance(node.value, gast.Name) and
node.value.id == 'paddle'):
return False
if node.attr != 'shape':
return False
return True
......
......@@ -217,6 +217,12 @@ def dyfunc_change_shape_after_assign(x):
return res
def dyfunc_len_paddle_shape():
x = paddle.to_tensor([1, 2, 3])
if len(paddle.shape(x)) > 0:
print(x)
# 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase):
def setUp(self):
......@@ -582,5 +588,11 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
func.concrete_program
class TestPaddleShape(unittest.TestCase):
def test_paddle_shape(self):
func = paddle.jit.to_static(dyfunc_len_paddle_shape)
self.assertEqual('paddle.shape(x)' in func.code, True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册