未验证 提交 5b339262 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Add Tuple as Assign Target for Tensor Shape (#28775)

Add support for using tuple as tensor.shape (For example: a, b, c, d = x.shape)
上级 5cb8e17a
......@@ -17,6 +17,7 @@ from __future__ import print_function
import copy
import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
......@@ -192,12 +193,39 @@ class TensorShapeTransformer(gast.NodeTransformer):
def _update_name_to_var_shape(self, node):
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
try:
target_id = target_node.id
except AttributeError:
return False
value_node = node.value
if isinstance(target_node, gast.Tuple):
has_updated = False
for idx, element in enumerate(target_node.elts):
target_id = ast_to_source_code(element).strip()
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape:
index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node)
var_shape_node = self.name_to_var_shape[value_node.id]
sub_node = gast.Subscript(
value=var_shape_node,
slice=slice_index_node,
ctx=gast.Load())
self.name_to_var_shape[target_id] = sub_node
has_updated = True
if isinstance(value_node, gast.Attribute):
if self.is_var_shape(value_node): # eg: x.shape
index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node)
sub_node = gast.Subscript(
value=value_node,
slice=slice_index_node,
ctx=gast.Load())
self.name_to_var_shape[target_id] = sub_node
has_updated = True
return has_updated
else:
target_id = ast_to_source_code(target_node).strip()
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape:
self.name_to_var_shape[target_id] = self.name_to_var_shape[
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import numpy
import unittest
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative
......@@ -59,6 +60,21 @@ def dyfunc_tensor_shape_5(x):
return res
def dyfunc_tuple_shape_1(x):
x = paddle.to_tensor(x)
a, b = x.shape
res = paddle.reshape(x, shape=(b, a))
return res
def dyfunc_tuple_shape_2(x):
x = paddle.to_tensor(x)
shape = x.shape
a, b = shape
res = paddle.reshape(x, shape=(b, a))
return res
def dyfunc_with_if_1(x):
x = fluid.dygraph.to_variable(x)
res = fluid.layers.reshape(x, [-1, 1])
......@@ -224,6 +240,18 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tensor_shape_5
class TestTupleShape1(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32")
self.dygraph_func = dyfunc_tuple_shape_1
class TestTupleShape2(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32")
self.dygraph_func = dyfunc_tuple_shape_2
# 2. Tests with control flow if
class TestTensorShapeInIf1(TestTensorShapeBasic):
def init_test_func(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册