From 55604248ded67d649a9ff646b3410f7f83ab7f73 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 11 Jan 2021 14:47:27 +0800 Subject: [PATCH] Skip convert tensor shape while using Paddle.shape (#30223) (#30239) * fix tensor shape bug * fix op_num * clean code --- .../tensor_shape_transformer.py | 11 ++++++++- .../dygraph_to_static/test_tensor_shape.py | 23 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 7c45c10a48e..6aa55042647 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -188,6 +188,14 @@ class TensorShapeTransformer(gast.NodeTransformer): return need_transformed def _used_by_paddle_api(self, node): + """ + Whether node is used in paddle api as arguments. + For example: + 1) Return True in `paddle.relu(x)` where node is `x` (gast.Name) + 2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute) + 3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute), + because the role of node is not arguments but `gast.Call.func`. + """ assert isinstance(node, (gast.Attribute, gast.Name)) wrapper_node = self.node_to_wrapper_map.get(node) if not wrapper_node: @@ -196,7 +204,8 @@ class TensorShapeTransformer(gast.NodeTransformer): while wrapper_node.parent: parent_node = wrapper_node.parent.node if isinstance(parent_node, gast.Call): - if is_paddle_api(parent_node): + # Note(Aurelius84): Filter the case when the role of node is `gast.Call.func`. + if is_paddle_api(parent_node) and parent_node.func != node: return True else: return False diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index dfc8d2429f4..17809ea16fd 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -75,6 +75,17 @@ def dyfunc_tuple_shape_2(x): return res +def dyfunc_paddle_shape_api(x): + x = paddle.to_tensor(x) + # paddle.shape will not be converted. + a = paddle.shape(x)[0] + # alias api will also not be converted. + alias_old_api = paddle.fluid.layers + b = alias_old_api.shape(x)[1] + res = paddle.reshape(x, shape=(b, a)) + return res + + def dyfunc_with_if_1(x): x = fluid.dygraph.to_variable(x) res = fluid.layers.reshape(x, [-1, 1]) @@ -283,6 +294,18 @@ class TestTupleShape2(TestTensorShapeBasic): self.dygraph_func = dyfunc_tuple_shape_2 +class TestPaddleShapeApi(TestTensorShapeBasic): + def init_test_func(self): + self.input = numpy.ones((5, 7)).astype("int32") + self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")] + self.dygraph_func = dyfunc_paddle_shape_api + + def _set_expected_op_num(self): + self.expected_op_num = 6 + self.expected_shape_op_num = 2 + self.expected_slice_op_num = 2 + + # 2. Tests with control flow if class TestTensorShapeInIf1(TestTensorShapeBasic): def init_test_func(self): -- GitLab