diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index bd89a79c805c98d6092f31573b86d8c7cea20a26..14bb54983b524ad1c09aa0d66f37b2b2aae6dbe8 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -378,6 +378,21 @@ class NameVisitor(gast.NodeVisitor): :param loop_node: Current loop node. """ + def filter_name_nodes_from(root_node, target_var_names): + """ + Filter children with gast.Name type from node.(inclusivly) + """ + name_nodes = set() + if isinstance(root_node, gast.Name): + if node.id in target_var_names: + name_nodes.add(root_node) + for child_node in gast.walk(root_node): + if isinstance(child_node, gast.Name): + if child_node.id in target_var_names: + name_nodes.add(child_node) + + return name_nodes + vars_of_list_generator = set() target_vars_of_for_node = set() @@ -412,15 +427,16 @@ class NameVisitor(gast.NodeVisitor): # 1.2 vars from target vars used in elt_node target_var_names = {var.id for var in target_vars} - listcomp_node = self._get_parent_node(parent_node) - elt_node = listcomp_node.elt - if isinstance(elt_node, gast.Name): - if elt_node.id in target_var_names: - vars_of_list_generator.add(elt_node) - for child_node in gast.walk(elt_node): - if isinstance(child_node, gast.Name): - if child_node.id in target_var_names: - vars_of_list_generator.add(child_node) + comp_node = self._get_parent_node(parent_node) + elt_nodes = [] + if isinstance(comp_node, gast.ListComp): + elt_nodes.append(comp_node.elt) + elif isinstance(comp_node, gast.DictComp): + elt_nodes.extend([comp_node.key, comp_node.value]) + + for node in elt_nodes: + vars_of_list_generator |= filter_name_nodes_from( + node, target_var_names) # 2. Get target vars or vars from target vars used in for-loop but the for-loop is # 1) not the "loop_node" itself diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 624ca085ac6c2dfa5c2ca7e2894b0f2afc6f6ea2..001116a74c9cc5f149de8ab1ebd7f8f5c2f68068 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -79,6 +79,7 @@ FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple' FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index' FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len' FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var' +FOR_ITER_ZIP_TO_LIST_PREFIX = '__for_loop_iter_zip' # FullArgSpec is valid from Python3. Defined a Namedtuple to # to make it available in Python2. @@ -1012,6 +1013,9 @@ class ForNodeVisitor(object): # - for i, x enumerate(var|var.numpy()) # - for x in var self.iter_var_len_name = unique_name.generate(FOR_ITER_VAR_LEN_PREFIX) + # - created zip to list var : __for_loop_iter_zip_0 + self.iter_zip_to_list_name = unique_name.generate( + FOR_ITER_ZIP_TO_LIST_PREFIX) # - var.numpy()/var # - for x in var|var.numpy() @@ -1083,6 +1087,7 @@ class ForNodeVisitor(object): def _parse_for_stmts(self): init_stmts = [] + init_stmts.extend(self._build_iter_node()) init_stmts.append(self._build_index_init_node()) init_stmts.append(self._build_var_len_assign_node()) @@ -1105,6 +1110,7 @@ class ForNodeVisitor(object): def _parse_for_enumerate_stmts(self): init_stmts = [] + init_stmts.extend(self._build_iter_node()) init_stmts.append(self._build_index_init_node()) init_stmts.append(self._build_var_len_assign_node()) init_stmts.append(self._build_enum_init_node()) @@ -1163,6 +1169,34 @@ class ForNodeVisitor(object): return convert_len_node + def _build_iter_node(self): + """ + Process special cases for iter_node inclue: + - Case 1 (for zip): + + - for i, val in enumerate(zip(x, y)) # original code: + + - __for_loop_iter_zip_0 = list(zip(x, y)) + - for i, val in enumerate(__for_loop_iter_zip_0) + """ + new_nodes = [] + if isinstance(self.iter_node, gast.Call) and isinstance( + self.iter_node.func, gast.Name): + if self.iter_node.func.id == 'zip': + iter_var_name = ast_to_source_code(self.iter_node).strip() + zip_to_list_str = "{target} = list({value})".format( + target=self.iter_zip_to_list_name, value=iter_var_name) + zip_to_list_node = gast.parse(zip_to_list_str).body[0] + new_nodes.append(zip_to_list_node) + + self.iter_node = gast.Name( + id=self.iter_zip_to_list_name, + ctx=gast.Load(), + annotation=None, + type_comment=None) + + return new_nodes + def _build_enum_init_node(self): if self.is_for_enumerate_iter() and self.args_length != 1: init_value_str = ast_to_source_code(self.iter_args[1]).strip() @@ -1399,6 +1433,7 @@ def input_specs_compatible(src_input_specs, desired_input_specs): for spec in src_input_specs: if spec not in desired_input_specs: return False + else: for i in range(len_specs): src_shape = src_input_specs[i].shape diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py index 3a7994ee67e9bc7289fff65af3837cbb874b2566..dbd3952991cfd745e3ef6ba231d15feb5020099f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py @@ -241,5 +241,39 @@ class TestDictPop(TestNetWithDict): static_result)) +class TestDictCmpInFor(unittest.TestCase): + def test_with_for(self): + def func(): + pos = [1, 3] + neg = [-1, -3] + dict_val = {'minus': 0} + # test `zip` with `for` + for (x, y) in zip(pos, neg): + val = x - y + dict_val.update( + {k: val + dict_val[k] + for k, v in dict_val.items()}) + + return dict_val + + self.assertEqual(paddle.jit.to_static(func)()['minus'], 8) + + def test_with_for_enumerate(self): + def func(): + pos = [1, 3] + neg = [-1, -3] + dict_val = {'minus': 0} + # test `zip` with `for` + for i, (x, y) in enumerate(zip(pos, neg)): + val = x - y + dict_val.update( + {k: val + dict_val[k] + for k, v in dict_val.items()}) + + return dict_val + + self.assertEqual(paddle.jit.to_static(func)()['minus'], 8) + + if __name__ == '__main__': unittest.main()