未验证 提交 55730d95 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] Support DictCmp and zip grammer (#32159)

* support DictCmp and zip grammar

* fix code style
上级 dabaca00
...@@ -378,6 +378,21 @@ class NameVisitor(gast.NodeVisitor): ...@@ -378,6 +378,21 @@ class NameVisitor(gast.NodeVisitor):
:param loop_node: Current loop node. :param loop_node: Current loop node.
""" """
def filter_name_nodes_from(root_node, target_var_names):
"""
Filter children with gast.Name type from node.(inclusivly)
"""
name_nodes = set()
if isinstance(root_node, gast.Name):
if node.id in target_var_names:
name_nodes.add(root_node)
for child_node in gast.walk(root_node):
if isinstance(child_node, gast.Name):
if child_node.id in target_var_names:
name_nodes.add(child_node)
return name_nodes
vars_of_list_generator = set() vars_of_list_generator = set()
target_vars_of_for_node = set() target_vars_of_for_node = set()
...@@ -412,15 +427,16 @@ class NameVisitor(gast.NodeVisitor): ...@@ -412,15 +427,16 @@ class NameVisitor(gast.NodeVisitor):
# 1.2 vars from target vars used in elt_node # 1.2 vars from target vars used in elt_node
target_var_names = {var.id for var in target_vars} target_var_names = {var.id for var in target_vars}
listcomp_node = self._get_parent_node(parent_node) comp_node = self._get_parent_node(parent_node)
elt_node = listcomp_node.elt elt_nodes = []
if isinstance(elt_node, gast.Name): if isinstance(comp_node, gast.ListComp):
if elt_node.id in target_var_names: elt_nodes.append(comp_node.elt)
vars_of_list_generator.add(elt_node) elif isinstance(comp_node, gast.DictComp):
for child_node in gast.walk(elt_node): elt_nodes.extend([comp_node.key, comp_node.value])
if isinstance(child_node, gast.Name):
if child_node.id in target_var_names: for node in elt_nodes:
vars_of_list_generator.add(child_node) vars_of_list_generator |= filter_name_nodes_from(
node, target_var_names)
# 2. Get target vars or vars from target vars used in for-loop but the for-loop is # 2. Get target vars or vars from target vars used in for-loop but the for-loop is
# 1) not the "loop_node" itself # 1) not the "loop_node" itself
......
...@@ -79,6 +79,7 @@ FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple' ...@@ -79,6 +79,7 @@ FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index' FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len' FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var' FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
FOR_ITER_ZIP_TO_LIST_PREFIX = '__for_loop_iter_zip'
# FullArgSpec is valid from Python3. Defined a Namedtuple to # FullArgSpec is valid from Python3. Defined a Namedtuple to
# to make it available in Python2. # to make it available in Python2.
...@@ -1012,6 +1013,9 @@ class ForNodeVisitor(object): ...@@ -1012,6 +1013,9 @@ class ForNodeVisitor(object):
# - for i, x enumerate(var|var.numpy()) # - for i, x enumerate(var|var.numpy())
# - for x in var # - for x in var
self.iter_var_len_name = unique_name.generate(FOR_ITER_VAR_LEN_PREFIX) self.iter_var_len_name = unique_name.generate(FOR_ITER_VAR_LEN_PREFIX)
# - created zip to list var : __for_loop_iter_zip_0
self.iter_zip_to_list_name = unique_name.generate(
FOR_ITER_ZIP_TO_LIST_PREFIX)
# - var.numpy()/var # - var.numpy()/var
# - for x in var|var.numpy() # - for x in var|var.numpy()
...@@ -1083,6 +1087,7 @@ class ForNodeVisitor(object): ...@@ -1083,6 +1087,7 @@ class ForNodeVisitor(object):
def _parse_for_stmts(self): def _parse_for_stmts(self):
init_stmts = [] init_stmts = []
init_stmts.extend(self._build_iter_node())
init_stmts.append(self._build_index_init_node()) init_stmts.append(self._build_index_init_node())
init_stmts.append(self._build_var_len_assign_node()) init_stmts.append(self._build_var_len_assign_node())
...@@ -1105,6 +1110,7 @@ class ForNodeVisitor(object): ...@@ -1105,6 +1110,7 @@ class ForNodeVisitor(object):
def _parse_for_enumerate_stmts(self): def _parse_for_enumerate_stmts(self):
init_stmts = [] init_stmts = []
init_stmts.extend(self._build_iter_node())
init_stmts.append(self._build_index_init_node()) init_stmts.append(self._build_index_init_node())
init_stmts.append(self._build_var_len_assign_node()) init_stmts.append(self._build_var_len_assign_node())
init_stmts.append(self._build_enum_init_node()) init_stmts.append(self._build_enum_init_node())
...@@ -1163,6 +1169,34 @@ class ForNodeVisitor(object): ...@@ -1163,6 +1169,34 @@ class ForNodeVisitor(object):
return convert_len_node return convert_len_node
def _build_iter_node(self):
"""
Process special cases for iter_node inclue:
- Case 1 (for zip):
- for i, val in enumerate(zip(x, y)) # original code:
- __for_loop_iter_zip_0 = list(zip(x, y))
- for i, val in enumerate(__for_loop_iter_zip_0)
"""
new_nodes = []
if isinstance(self.iter_node, gast.Call) and isinstance(
self.iter_node.func, gast.Name):
if self.iter_node.func.id == 'zip':
iter_var_name = ast_to_source_code(self.iter_node).strip()
zip_to_list_str = "{target} = list({value})".format(
target=self.iter_zip_to_list_name, value=iter_var_name)
zip_to_list_node = gast.parse(zip_to_list_str).body[0]
new_nodes.append(zip_to_list_node)
self.iter_node = gast.Name(
id=self.iter_zip_to_list_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)
return new_nodes
def _build_enum_init_node(self): def _build_enum_init_node(self):
if self.is_for_enumerate_iter() and self.args_length != 1: if self.is_for_enumerate_iter() and self.args_length != 1:
init_value_str = ast_to_source_code(self.iter_args[1]).strip() init_value_str = ast_to_source_code(self.iter_args[1]).strip()
...@@ -1399,6 +1433,7 @@ def input_specs_compatible(src_input_specs, desired_input_specs): ...@@ -1399,6 +1433,7 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
for spec in src_input_specs: for spec in src_input_specs:
if spec not in desired_input_specs: if spec not in desired_input_specs:
return False return False
else: else:
for i in range(len_specs): for i in range(len_specs):
src_shape = src_input_specs[i].shape src_shape = src_input_specs[i].shape
......
...@@ -241,5 +241,39 @@ class TestDictPop(TestNetWithDict): ...@@ -241,5 +241,39 @@ class TestDictPop(TestNetWithDict):
static_result)) static_result))
class TestDictCmpInFor(unittest.TestCase):
def test_with_for(self):
def func():
pos = [1, 3]
neg = [-1, -3]
dict_val = {'minus': 0}
# test `zip` with `for`
for (x, y) in zip(pos, neg):
val = x - y
dict_val.update(
{k: val + dict_val[k]
for k, v in dict_val.items()})
return dict_val
self.assertEqual(paddle.jit.to_static(func)()['minus'], 8)
def test_with_for_enumerate(self):
def func():
pos = [1, 3]
neg = [-1, -3]
dict_val = {'minus': 0}
# test `zip` with `for`
for i, (x, y) in enumerate(zip(pos, neg)):
val = x - y
dict_val.update(
{k: val + dict_val[k]
for k, v in dict_val.items()})
return dict_val
self.assertEqual(paddle.jit.to_static(func)()['minus'], 8)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册