From 508224913ccca5ccb7c93b6c436df3d8a0ac8d71 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Wed, 15 Dec 2021 17:36:36 +0800 Subject: [PATCH] [Dy2stat]Fix error in tensor_shape_transformer. (#37999) * fix error when tensor_shape_transformer. Before in stmt like `if len(paddle.shape(x)[0]) > 0`, `paddle` will be used as a variable * handle other call like `fluid.layers.mean` and `fluid.layers.shape` * add unit test --- .../dygraph_to_static/tensor_shape_transformer.py | 4 ++++ .../unittests/dygraph_to_static/test_tensor_shape.py | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 0bc167132e..e1df232488 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -282,6 +282,10 @@ class TensorShapeTransformer(gast.NodeTransformer): return False if isinstance(node, gast.Attribute): + # If node is `paddle.shape`, return False + if (node.attr == 'shape' and isinstance(node.value, gast.Name) and + node.value.id == 'paddle'): + return False if node.attr != 'shape': return False return True diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index f7cdb12a1a..06d69daa75 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -217,6 +217,12 @@ def dyfunc_change_shape_after_assign(x): return res +def dyfunc_len_paddle_shape(): + x = paddle.to_tensor([1, 2, 3]) + if len(paddle.shape(x)) > 0: + print(x) + + # 1. Basic tests without control flow class TestTensorShapeBasic(unittest.TestCase): def setUp(self): @@ -582,5 +588,11 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase): func.concrete_program +class TestPaddleShape(unittest.TestCase): + def test_paddle_shape(self): + func = paddle.jit.to_static(dyfunc_len_paddle_shape) + self.assertEqual('paddle.shape(x)' in func.code, True) + + if __name__ == '__main__': unittest.main() -- GitLab