未验证 提交 03e07273 编写于 作者: A Aurelius84 提交者: GitHub

Skip convert tensor shape while using Paddle.shape (#30223)

* fix tensor shape bug

* fix op_num

* clean code
上级 49411a20
...@@ -188,6 +188,14 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -188,6 +188,14 @@ class TensorShapeTransformer(gast.NodeTransformer):
return need_transformed return need_transformed
def _used_by_paddle_api(self, node): def _used_by_paddle_api(self, node):
"""
Whether node is used in paddle api as arguments.
For example:
1) Return True in `paddle.relu(x)` where node is `x` (gast.Name)
2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute)
3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute),
because the role of node is not arguments but `gast.Call.func`.
"""
assert isinstance(node, (gast.Attribute, gast.Name)) assert isinstance(node, (gast.Attribute, gast.Name))
wrapper_node = self.node_to_wrapper_map.get(node) wrapper_node = self.node_to_wrapper_map.get(node)
if not wrapper_node: if not wrapper_node:
...@@ -196,7 +204,8 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -196,7 +204,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
while wrapper_node.parent: while wrapper_node.parent:
parent_node = wrapper_node.parent.node parent_node = wrapper_node.parent.node
if isinstance(parent_node, gast.Call): if isinstance(parent_node, gast.Call):
if is_paddle_api(parent_node): # Note(Aurelius84): Filter the case when the role of node is `gast.Call.func`.
if is_paddle_api(parent_node) and parent_node.func != node:
return True return True
else: else:
return False return False
......
...@@ -75,6 +75,17 @@ def dyfunc_tuple_shape_2(x): ...@@ -75,6 +75,17 @@ def dyfunc_tuple_shape_2(x):
return res return res
def dyfunc_paddle_shape_api(x):
x = paddle.to_tensor(x)
# paddle.shape will not be converted.
a = paddle.shape(x)[0]
# alias api will also not be converted.
alias_old_api = paddle.fluid.layers
b = alias_old_api.shape(x)[1]
res = paddle.reshape(x, shape=(b, a))
return res
def dyfunc_with_if_1(x): def dyfunc_with_if_1(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
res = fluid.layers.reshape(x, [-1, 1]) res = fluid.layers.reshape(x, [-1, 1])
...@@ -283,6 +294,18 @@ class TestTupleShape2(TestTensorShapeBasic): ...@@ -283,6 +294,18 @@ class TestTupleShape2(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tuple_shape_2 self.dygraph_func = dyfunc_tuple_shape_2
class TestPaddleShapeApi(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32")
self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
self.dygraph_func = dyfunc_paddle_shape_api
def _set_expected_op_num(self):
self.expected_op_num = 6
self.expected_shape_op_num = 2
self.expected_slice_op_num = 2
# 2. Tests with control flow if # 2. Tests with control flow if
class TestTensorShapeInIf1(TestTensorShapeBasic): class TestTensorShapeInIf1(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册