From 5b339262bc25041cf0208ffa8b8a63d9d26d1b16 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Tue, 24 Nov 2020 17:13:24 +0800 Subject: [PATCH] [Dy2stat] Add Tuple as Assign Target for Tensor Shape (#28775) Add support for using tuple as tensor.shape (For example: a, b, c, d = x.shape) --- .../tensor_shape_transformer.py | 60 ++++++++++++++----- .../dygraph_to_static/test_tensor_shape.py | 28 +++++++++ 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 6cdf2799624..31de609e9fc 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -17,6 +17,7 @@ from __future__ import print_function import copy import gast +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper @@ -192,24 +193,51 @@ class TensorShapeTransformer(gast.NodeTransformer): def _update_name_to_var_shape(self, node): assert isinstance(node, gast.Assign) target_node = node.targets[0] - try: - target_id = target_node.id - except AttributeError: - return False value_node = node.value - if isinstance(value_node, gast.Name): - if value_node.id in self.name_to_var_shape: - self.name_to_var_shape[target_id] = self.name_to_var_shape[ - value_node.id] - return True - if isinstance(value_node, gast.Attribute): - if self.is_var_shape(value_node): # eg: x.shape - self.name_to_var_shape[target_id] = value_node - return True - if isinstance(value_node, gast.Subscript): - if isinstance(value_node.value, gast.Attribute): - if self.is_var_shape(value_node.value): # eg: x.shape[0] + if isinstance(target_node, gast.Tuple): + has_updated = False + for idx, element in enumerate(target_node.elts): + target_id = ast_to_source_code(element).strip() + + if isinstance(value_node, gast.Name): + if value_node.id in self.name_to_var_shape: + index_value_node = gast.Constant(value=idx, kind=None) + slice_index_node = gast.Index(value=index_value_node) + var_shape_node = self.name_to_var_shape[value_node.id] + sub_node = gast.Subscript( + value=var_shape_node, + slice=slice_index_node, + ctx=gast.Load()) + self.name_to_var_shape[target_id] = sub_node + has_updated = True + if isinstance(value_node, gast.Attribute): + if self.is_var_shape(value_node): # eg: x.shape + index_value_node = gast.Constant(value=idx, kind=None) + slice_index_node = gast.Index(value=index_value_node) + sub_node = gast.Subscript( + value=value_node, + slice=slice_index_node, + ctx=gast.Load()) + self.name_to_var_shape[target_id] = sub_node + has_updated = True + + return has_updated + else: + target_id = ast_to_source_code(target_node).strip() + + if isinstance(value_node, gast.Name): + if value_node.id in self.name_to_var_shape: + self.name_to_var_shape[target_id] = self.name_to_var_shape[ + value_node.id] + return True + if isinstance(value_node, gast.Attribute): + if self.is_var_shape(value_node): # eg: x.shape self.name_to_var_shape[target_id] = value_node return True + if isinstance(value_node, gast.Subscript): + if isinstance(value_node.value, gast.Attribute): + if self.is_var_shape(value_node.value): # eg: x.shape[0] + self.name_to_var_shape[target_id] = value_node + return True return False diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index de9554a2d4a..53dbb07c97f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -17,6 +17,7 @@ from __future__ import print_function import numpy import unittest +import paddle import paddle.fluid as fluid from paddle.fluid.dygraph.jit import declarative @@ -59,6 +60,21 @@ def dyfunc_tensor_shape_5(x): return res +def dyfunc_tuple_shape_1(x): + x = paddle.to_tensor(x) + a, b = x.shape + res = paddle.reshape(x, shape=(b, a)) + return res + + +def dyfunc_tuple_shape_2(x): + x = paddle.to_tensor(x) + shape = x.shape + a, b = shape + res = paddle.reshape(x, shape=(b, a)) + return res + + def dyfunc_with_if_1(x): x = fluid.dygraph.to_variable(x) res = fluid.layers.reshape(x, [-1, 1]) @@ -224,6 +240,18 @@ class TestTensorShapeBasic5(TestTensorShapeBasic): self.dygraph_func = dyfunc_tensor_shape_5 +class TestTupleShape1(TestTensorShapeBasic): + def init_test_func(self): + self.input = numpy.ones((5, 7)).astype("int32") + self.dygraph_func = dyfunc_tuple_shape_1 + + +class TestTupleShape2(TestTensorShapeBasic): + def init_test_func(self): + self.input = numpy.ones((5, 7)).astype("int32") + self.dygraph_func = dyfunc_tuple_shape_2 + + # 2. Tests with control flow if class TestTensorShapeInIf1(TestTensorShapeBasic): def init_test_func(self): -- GitLab