diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 9c338546e233323ef50cb78c61907f803bf8cb17..4bfb310a835e20555e265d32cd572b905aebe23d 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -245,23 +245,51 @@ def get_name_ids(nodes, end_node=None): return name_visitor.name_ids -def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load): +def parse_cond_args(parent_ids_dict, + var_ids_dict, + modified_ids_dict=None, + ctx=gast.Load): """ Find out the ast.Name.id list of input by analyzing node's AST information. """ - name_ids = [ + # 1. filter the var fit the ctx + arg_name_ids = [ var_id for var_id, var_ctx in six.iteritems(var_ids_dict) if isinstance(var_ctx[0], ctx) ] - if return_ids: - new_args = set(return_ids) - set(name_ids) - name_ids.extend(list(new_args)) - name_ids.sort() + + # 2. args should contain modified var ids in if-body or else-body + # case: + # + # ``` + # if b < 1: + # z = y + # else: + # z = x + # ``` + # + # In the above case, `z` should be in the args of cond() + if modified_ids_dict: + arg_name_ids = set(arg_name_ids) | set(modified_ids_dict) + + # 3. args should not contain the vars not in parent ids + # case : + # + # ``` + # x = 1 + # if x > y: + # z = [v for v in range(i)] + # ``` + # + # In the above case, `v` should not be in the args of cond() + arg_name_ids = list(set(arg_name_ids) & set(parent_ids_dict)) + + arg_name_ids.sort() args = [ gast.Name( id=name_id, ctx=gast.Load(), annotation=None, type_comment=None) - for name_id in name_ids + for name_id in arg_name_ids ] arguments = gast.arguments( args=args, @@ -412,7 +440,7 @@ def transform_if_else(node, root): all_name_ids = get_name_ids([root]) for name in all_name_ids: before_var_names_ids = parent_name_ids.get(name, []) + \ - body_name_ids.get(name, []) + orelse_name_ids.get(name, []) + body_name_ids.get(name, []) + orelse_name_ids.get(name, []) # Note: context of node.Name like gast.Load is a concrete object which has unique id different from other gast.Load # E.g. ctx of `x` can be [, , ] after_var_names_ids = [ @@ -444,12 +472,14 @@ def transform_if_else(node, root): true_func_node = create_funcDef_node( node.body, name=unique_name.generate(TRUE_FUNC_PREFIX), - input_args=parse_cond_args(body_name_ids, modified_name_ids), + input_args=parse_cond_args(parent_name_ids, body_name_ids, + modified_name_ids), return_name_ids=return_name_ids) false_func_node = create_funcDef_node( node.orelse, name=unique_name.generate(FALSE_FUNC_PREFIX), - input_args=parse_cond_args(orelse_name_ids, modified_name_ids), + input_args=parse_cond_args(parent_name_ids, orelse_name_ids, + modified_name_ids), return_name_ids=return_name_ids) return create_new_vars_in_parent_stmts, true_func_node, false_func_node, return_name_ids diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index 34d7b59a9b487b2424be8c9368b7dcc6dcbf59c0..b343c54d6b1ee640a325ba719aa483e193f8cbf6 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -14,6 +14,7 @@ from __future__ import print_function +import paddle import paddle.fluid as fluid @@ -99,6 +100,16 @@ def dyfunc_with_if_else3(x): return x +def dyfunc_with_if_else_with_list_geneator(x): + if 10 > 5: + y = paddle.add_n( + [paddle.full( + shape=[2], fill_value=v) for v in range(5)]) + else: + y = x + return y + + def nested_if_else(x_v): batch_size = 16 feat_size = x_v.shape[-1] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 5656c7fce81e3957b3d0318a2edd27483494ba6a..d8d4634ae508fac81722ade1cb9d0b9d6d453089 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -69,6 +69,12 @@ class TestDygraphIfElse3(TestDygraphIfElse): self.dyfunc = dyfunc_with_if_else3 +class TestDygraphIfElseWithListGenerator(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = dyfunc_with_if_else_with_list_geneator + + class TestDygraphNestedIfElse(TestDygraphIfElse): def setUp(self): self.x = np.random.random([10, 16]).astype('float32')