未验证 提交 19eb5100 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2stat]Fix error in tensor_shape_transformer. (#37999) (#38168)

修复tensor_shape_transformer中的错误。
之前在类似if len(paddle.shape(x)[0]) > 0中,paddle会被当做一个变量被传入convert_var_shape函数中
上级 8100c16a
......@@ -282,6 +282,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False
if isinstance(node, gast.Attribute):
# If node is `paddle.shape`, return False
if (node.attr == 'shape' and isinstance(node.value, gast.Name) and
node.value.id == 'paddle'):
return False
if node.attr != 'shape':
return False
return True
......
......@@ -217,6 +217,12 @@ def dyfunc_change_shape_after_assign(x):
return res
def dyfunc_len_paddle_shape():
x = paddle.to_tensor([1, 2, 3])
if len(paddle.shape(x)) > 0:
print(x)
# 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase):
def setUp(self):
......@@ -582,5 +588,11 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
func.concrete_program
class TestPaddleShape(unittest.TestCase):
def test_paddle_shape(self):
func = paddle.jit.to_static(dyfunc_len_paddle_shape)
self.assertEqual('paddle.shape(x)' in func.code, True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册