diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 6aa550426470f06d4dbf12786002f74150446340..98906d015808204bcca5b9f8acb4626217499364 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -31,13 +31,20 @@ def create_convert_shape_node(var_shape_node, if isinstance(var_shape_node, gast.Attribute): args = [ast_to_source_code(var_shape_node.value).strip()] - if slice_node: + # (1) A slice can be a simple number such as 1, -2, i.e. gast.Index + # (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index + # In (1) case, we pass the number as 'idx' argument in convert_var_shape + # In (2) case, we have to make it like `convert_var_shape(x)[slice]` + if slice_node is not None and isinstance(slice_node, gast.Index): args.append(ast_to_source_code(slice_node).strip()) convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format( ",".join(args), in_control_flow) - api_shape_node = gast.parse(convert_var_shape_func).body[0].value + + if slice_node is not None and not isinstance(slice_node, gast.Index): + return gast.Subscript( + value=api_shape_node, slice=slice_node, ctx=gast.Load()) return api_shape_node if isinstance(var_shape_node, gast.Subscript): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index 17809ea16fd1f83ba2a0afa52f6411d40fed7f61..7a4c63894f9766f1aeda864ad5ff2b49f50b7e63 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -60,6 +60,16 @@ def dyfunc_tensor_shape_5(x): return res +def dyfunc_tensor_shape_6(x): + # `res = fluid.layers.reshape(x, shape=(-1, s))` to + # `res = fluid.layers.reshape(x, shape=(-1, + # paddle.jit.dy2static.convert_var_shape(x)[0:]))` + x = fluid.dygraph.to_variable(x) + s = x.shape[0:] + res = fluid.layers.reshape(x, shape=s) + return res + + def dyfunc_tuple_shape_1(x): x = paddle.to_tensor(x) a, b = x.shape @@ -280,6 +290,11 @@ class TestTensorShapeBasic5(TestTensorShapeBasic): self.dygraph_func = dyfunc_tensor_shape_5 +class TestTensorShapeBasic6(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_tensor_shape_6 + + class TestTupleShape1(TestTensorShapeBasic): def init_test_func(self): self.input = numpy.ones((5, 7)).astype("int32")