From 19eb51002b74341e24fd2900ad6be44a90714943 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Thu, 16 Dec 2021 19:43:45 +0800 Subject: [PATCH] [Dy2stat]Fix error in tensor_shape_transformer. (#37999) (#38168) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复tensor_shape_transformer中的错误。 之前在类似if len(paddle.shape(x)[0]) > 0中,paddle会被当做一个变量被传入convert_var_shape函数中 --- .../dygraph_to_static/tensor_shape_transformer.py | 4 ++++ .../unittests/dygraph_to_static/test_tensor_shape.py | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 0bc167132e3..e1df2324889 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -282,6 +282,10 @@ class TensorShapeTransformer(gast.NodeTransformer): return False if isinstance(node, gast.Attribute): + # If node is `paddle.shape`, return False + if (node.attr == 'shape' and isinstance(node.value, gast.Name) and + node.value.id == 'paddle'): + return False if node.attr != 'shape': return False return True diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index f7cdb12a1ab..06d69daa75d 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -217,6 +217,12 @@ def dyfunc_change_shape_after_assign(x): return res +def dyfunc_len_paddle_shape(): + x = paddle.to_tensor([1, 2, 3]) + if len(paddle.shape(x)) > 0: + print(x) + + # 1. Basic tests without control flow class TestTensorShapeBasic(unittest.TestCase): def setUp(self): @@ -582,5 +588,11 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase): func.concrete_program +class TestPaddleShape(unittest.TestCase): + def test_paddle_shape(self): + func = paddle.jit.to_static(dyfunc_len_paddle_shape) + self.assertEqual('paddle.shape(x)' in func.code, True) + + if __name__ == '__main__': unittest.main() -- GitLab