From ef627ac5b9223284a8813239e57fa9ef1a53b710 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Sat, 20 Feb 2021 10:19:25 +0800 Subject: [PATCH] Fix that convert_var_shape doesn't support slice like [0:], test=develop (#31051) As the title, when slice_node like 1:3 being passed to idx of convert_var_shape, it will cause syntax error because a function cannot take this as argument. This PR fixed it. --- .../dygraph_to_static/tensor_shape_transformer.py | 11 +++++++++-- .../dygraph_to_static/test_tensor_shape.py | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 6aa5504264..98906d0158 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -31,13 +31,20 @@ def create_convert_shape_node(var_shape_node, if isinstance(var_shape_node, gast.Attribute): args = [ast_to_source_code(var_shape_node.value).strip()] - if slice_node: + # (1) A slice can be a simple number such as 1, -2, i.e. gast.Index + # (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index + # In (1) case, we pass the number as 'idx' argument in convert_var_shape + # In (2) case, we have to make it like `convert_var_shape(x)[slice]` + if slice_node is not None and isinstance(slice_node, gast.Index): args.append(ast_to_source_code(slice_node).strip()) convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format( ",".join(args), in_control_flow) - api_shape_node = gast.parse(convert_var_shape_func).body[0].value + + if slice_node is not None and not isinstance(slice_node, gast.Index): + return gast.Subscript( + value=api_shape_node, slice=slice_node, ctx=gast.Load()) return api_shape_node if isinstance(var_shape_node, gast.Subscript): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index 17809ea16f..7a4c63894f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -60,6 +60,16 @@ def dyfunc_tensor_shape_5(x): return res +def dyfunc_tensor_shape_6(x): + # `res = fluid.layers.reshape(x, shape=(-1, s))` to + # `res = fluid.layers.reshape(x, shape=(-1, + # paddle.jit.dy2static.convert_var_shape(x)[0:]))` + x = fluid.dygraph.to_variable(x) + s = x.shape[0:] + res = fluid.layers.reshape(x, shape=s) + return res + + def dyfunc_tuple_shape_1(x): x = paddle.to_tensor(x) a, b = x.shape @@ -280,6 +290,11 @@ class TestTensorShapeBasic5(TestTensorShapeBasic): self.dygraph_func = dyfunc_tensor_shape_5 +class TestTensorShapeBasic6(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_tensor_shape_6 + + class TestTupleShape1(TestTensorShapeBasic): def init_test_func(self): self.input = numpy.ones((5, 7)).astype("int32") -- GitLab