未验证 提交 5b339262 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Add Tuple as Assign Target for Tensor Shape (#28775)

Add support for using tuple as tensor.shape (For example: a, b, c, d = x.shape)
上级 5cb8e17a
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import copy import copy
import gast import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
...@@ -192,24 +193,51 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -192,24 +193,51 @@ class TensorShapeTransformer(gast.NodeTransformer):
def _update_name_to_var_shape(self, node): def _update_name_to_var_shape(self, node):
assert isinstance(node, gast.Assign) assert isinstance(node, gast.Assign)
target_node = node.targets[0] target_node = node.targets[0]
try:
target_id = target_node.id
except AttributeError:
return False
value_node = node.value value_node = node.value
if isinstance(value_node, gast.Name): if isinstance(target_node, gast.Tuple):
if value_node.id in self.name_to_var_shape: has_updated = False
self.name_to_var_shape[target_id] = self.name_to_var_shape[ for idx, element in enumerate(target_node.elts):
value_node.id] target_id = ast_to_source_code(element).strip()
return True
if isinstance(value_node, gast.Attribute): if isinstance(value_node, gast.Name):
if self.is_var_shape(value_node): # eg: x.shape if value_node.id in self.name_to_var_shape:
self.name_to_var_shape[target_id] = value_node index_value_node = gast.Constant(value=idx, kind=None)
return True slice_index_node = gast.Index(value=index_value_node)
if isinstance(value_node, gast.Subscript): var_shape_node = self.name_to_var_shape[value_node.id]
if isinstance(value_node.value, gast.Attribute): sub_node = gast.Subscript(
if self.is_var_shape(value_node.value): # eg: x.shape[0] value=var_shape_node,
slice=slice_index_node,
ctx=gast.Load())
self.name_to_var_shape[target_id] = sub_node
has_updated = True
if isinstance(value_node, gast.Attribute):
if self.is_var_shape(value_node): # eg: x.shape
index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node)
sub_node = gast.Subscript(
value=value_node,
slice=slice_index_node,
ctx=gast.Load())
self.name_to_var_shape[target_id] = sub_node
has_updated = True
return has_updated
else:
target_id = ast_to_source_code(target_node).strip()
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape:
self.name_to_var_shape[target_id] = self.name_to_var_shape[
value_node.id]
return True
if isinstance(value_node, gast.Attribute):
if self.is_var_shape(value_node): # eg: x.shape
self.name_to_var_shape[target_id] = value_node self.name_to_var_shape[target_id] = value_node
return True return True
if isinstance(value_node, gast.Subscript):
if isinstance(value_node.value, gast.Attribute):
if self.is_var_shape(value_node.value): # eg: x.shape[0]
self.name_to_var_shape[target_id] = value_node
return True
return False return False
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import numpy import numpy
import unittest import unittest
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
...@@ -59,6 +60,21 @@ def dyfunc_tensor_shape_5(x): ...@@ -59,6 +60,21 @@ def dyfunc_tensor_shape_5(x):
return res return res
def dyfunc_tuple_shape_1(x):
x = paddle.to_tensor(x)
a, b = x.shape
res = paddle.reshape(x, shape=(b, a))
return res
def dyfunc_tuple_shape_2(x):
x = paddle.to_tensor(x)
shape = x.shape
a, b = shape
res = paddle.reshape(x, shape=(b, a))
return res
def dyfunc_with_if_1(x): def dyfunc_with_if_1(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
res = fluid.layers.reshape(x, [-1, 1]) res = fluid.layers.reshape(x, [-1, 1])
...@@ -224,6 +240,18 @@ class TestTensorShapeBasic5(TestTensorShapeBasic): ...@@ -224,6 +240,18 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tensor_shape_5 self.dygraph_func = dyfunc_tensor_shape_5
class TestTupleShape1(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32")
self.dygraph_func = dyfunc_tuple_shape_1
class TestTupleShape2(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32")
self.dygraph_func = dyfunc_tuple_shape_2
# 2. Tests with control flow if # 2. Tests with control flow if
class TestTensorShapeInIf1(TestTensorShapeBasic): class TestTensorShapeInIf1(TestTensorShapeBasic):
def init_test_func(self): def init_test_func(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册