未验证 提交 6e1fe4f1 编写于 作者: 0 0x45f 提交者: GitHub

Support assign x.shape to dict['key'] in dy2st (#40611)

* support assign x.shape to dict['key'] in dy2st

* remove replace_dot

* refine unit test
上级 579173d8
......@@ -297,10 +297,6 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False
def _update_name_to_var_shape(self, node):
def replace_dot(name):
# replace all '.' into '_'
return name.replace('.', '_')
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
value_node = node.value
......@@ -315,7 +311,6 @@ class TensorShapeTransformer(gast.NodeTransformer):
if value_node.id in self.name_to_var_shape:
# TODO(zhhsplendid): is context a problem for the result node of gast.parse?
static_shape_var_name = unique_name.generate(
replace_dot(target_id) +
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
......@@ -337,7 +332,6 @@ class TensorShapeTransformer(gast.NodeTransformer):
if isinstance(value_node, gast.Attribute):
if self._is_var_shape(value_node): # eg: x.shape
static_shape_var_name = unique_name.generate(
replace_dot(target_id) +
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
......@@ -370,7 +364,6 @@ class TensorShapeTransformer(gast.NodeTransformer):
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape:
static_shape_var_name = unique_name.generate(
replace_dot(target_id) +
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
......@@ -387,7 +380,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
self.name_to_var_shape[target_id] = static_shape_var_name
elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0]
static_shape_var_name = unique_name.generate(
replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(static_shape_var_name).body[
0].value
static_shape_value_node = copy.deepcopy(value_node)
......
......@@ -223,6 +223,12 @@ def dyfunc_len_paddle_shape():
print(x)
def dyfunc_dict_assign_shape():
x = paddle.to_tensor([1, 2])
a = {}
a['shape'] = x.shape[0]
# 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase):
def setUp(self):
......@@ -592,6 +598,8 @@ class TestPaddleShape(unittest.TestCase):
def test_paddle_shape(self):
func = paddle.jit.to_static(dyfunc_len_paddle_shape)
self.assertEqual('paddle.shape(x)' in func.code, True)
func = paddle.jit.to_static(dyfunc_dict_assign_shape)
self.assertEqual("__static_convert_var_shape_suffix" in func.code, True)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册