未验证 提交 ef627ac5 编写于 作者: H Huihuang Zheng 提交者: GitHub

Fix that convert_var_shape doesn't support slice like [0:], test=develop (#31051)

As the title, when slice_node like 1:3 being passed to idx of convert_var_shape, it will cause syntax error because a function cannot take this as argument. This PR fixed it.
上级 f7465641
......@@ -31,13 +31,20 @@ def create_convert_shape_node(var_shape_node,
if isinstance(var_shape_node, gast.Attribute):
args = [ast_to_source_code(var_shape_node.value).strip()]
if slice_node:
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
if slice_node is not None and isinstance(slice_node, gast.Index):
args.append(ast_to_source_code(slice_node).strip())
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
",".join(args), in_control_flow)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
if slice_node is not None and not isinstance(slice_node, gast.Index):
return gast.Subscript(
value=api_shape_node, slice=slice_node, ctx=gast.Load())
return api_shape_node
if isinstance(var_shape_node, gast.Subscript):
......
......@@ -60,6 +60,16 @@ def dyfunc_tensor_shape_5(x):
return res
def dyfunc_tensor_shape_6(x):
# `res = fluid.layers.reshape(x, shape=(-1, s))` to
# `res = fluid.layers.reshape(x, shape=(-1,
# paddle.jit.dy2static.convert_var_shape(x)[0:]))`
x = fluid.dygraph.to_variable(x)
s = x.shape[0:]
res = fluid.layers.reshape(x, shape=s)
return res
def dyfunc_tuple_shape_1(x):
x = paddle.to_tensor(x)
a, b = x.shape
......@@ -280,6 +290,11 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tensor_shape_5
class TestTensorShapeBasic6(TestTensorShapeBasic):
def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_6
class TestTupleShape1(TestTensorShapeBasic):
def init_test_func(self):
self.input = numpy.ones((5, 7)).astype("int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册