From 6e1fe4f1a8eb32aa6d6f3080d1ea22e08c3ee11b Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Fri, 18 Mar 2022 15:49:15 +0800 Subject: [PATCH] Support assign x.shape to dict['key'] in dy2st (#40611) * support assign x.shape to dict['key'] in dy2st * remove replace_dot * refine unit test --- .../dygraph_to_static/tensor_shape_transformer.py | 9 +-------- .../unittests/dygraph_to_static/test_tensor_shape.py | 8 ++++++++ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index e1df232488..7733226cc0 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -297,10 +297,6 @@ class TensorShapeTransformer(gast.NodeTransformer): return False def _update_name_to_var_shape(self, node): - def replace_dot(name): - # replace all '.' into '_' - return name.replace('.', '_') - assert isinstance(node, gast.Assign) target_node = node.targets[0] value_node = node.value @@ -315,7 +311,6 @@ class TensorShapeTransformer(gast.NodeTransformer): if value_node.id in self.name_to_var_shape: # TODO(zhhsplendid): is context a problem for the result node of gast.parse? static_shape_var_name = unique_name.generate( - replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX) static_shape_var_node = gast.parse( static_shape_var_name).body[0].value @@ -337,7 +332,6 @@ class TensorShapeTransformer(gast.NodeTransformer): if isinstance(value_node, gast.Attribute): if self._is_var_shape(value_node): # eg: x.shape static_shape_var_name = unique_name.generate( - replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX) static_shape_var_node = gast.parse( static_shape_var_name).body[0].value @@ -370,7 +364,6 @@ class TensorShapeTransformer(gast.NodeTransformer): if isinstance(value_node, gast.Name): if value_node.id in self.name_to_var_shape: static_shape_var_name = unique_name.generate( - replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX) static_shape_var_node = gast.parse( static_shape_var_name).body[0].value @@ -387,7 +380,7 @@ class TensorShapeTransformer(gast.NodeTransformer): self.name_to_var_shape[target_id] = static_shape_var_name elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0] static_shape_var_name = unique_name.generate( - replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + STATIC_CONVERT_VAR_SHAPE_SUFFIX) static_shape_var_node = gast.parse(static_shape_var_name).body[ 0].value static_shape_value_node = copy.deepcopy(value_node) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index 06d69daa75..d05be03bbf 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -223,6 +223,12 @@ def dyfunc_len_paddle_shape(): print(x) +def dyfunc_dict_assign_shape(): + x = paddle.to_tensor([1, 2]) + a = {} + a['shape'] = x.shape[0] + + # 1. Basic tests without control flow class TestTensorShapeBasic(unittest.TestCase): def setUp(self): @@ -592,6 +598,8 @@ class TestPaddleShape(unittest.TestCase): def test_paddle_shape(self): func = paddle.jit.to_static(dyfunc_len_paddle_shape) self.assertEqual('paddle.shape(x)' in func.code, True) + func = paddle.jit.to_static(dyfunc_dict_assign_shape) + self.assertEqual("__static_convert_var_shape_suffix" in func.code, True) if __name__ == '__main__': -- GitLab