diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 7c31093568e79a3f0a7446862be5dda5e1f01b66..7c2974776f2d05ca354944aa3554bb9a01493637 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -26,7 +26,6 @@ from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import compare_with_none from paddle.fluid.dygraph.dygraph_to_static.utils import is_candidate_node -from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node @@ -34,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node TRUE_FUNC_PREFIX = 'true_fn' FALSE_FUNC_PREFIX = 'false_fn' @@ -55,14 +55,12 @@ class IfElseTransformer(gast.NodeTransformer): wrapper_root) self.root = wrapper_root.node self.static_analysis_visitor = StaticAnalysisVisitor(self.root) - self.new_func_nodes = {} def transform(self): """ Main function to transform AST. """ self.visit(self.root) - self.after_visit(self.root) def visit_If(self, node): if_condition_visitor = IfConditionVisitor(node.test, @@ -71,14 +69,14 @@ class IfElseTransformer(gast.NodeTransformer): self.generic_visit(node) if need_transform: pred_node, new_assign_nodes = if_condition_visitor.transform() - true_func_node, false_func_node, return_name_ids = transform_if_else( + new_vars_stmts, true_func_node, false_func_node, return_name_ids = transform_if_else( node, self.root) # create layers.cond - new_node = create_cond_node(return_name_ids, pred_node, - true_func_node, false_func_node) - self.new_func_nodes[new_node] = [true_func_node, false_func_node - ] + new_assign_nodes - return new_node + cond_node = create_cond_node(return_name_ids, pred_node, + true_func_node, false_func_node) + + return new_vars_stmts + [true_func_node, false_func_node + ] + new_assign_nodes + [cond_node] else: return node @@ -117,43 +115,6 @@ class IfElseTransformer(gast.NodeTransformer): else: return node - def after_visit(self, node): - """ - This function will add some postprocessing operations with node. - It can be used to add the created `true_fn/false_fn` in front of - the node.body before they are called in cond layer. - """ - self._insert_func_nodes(node) - - def _insert_func_nodes(self, node): - """ - Defined `true_func` and `false_func` will be inserted in front of corresponding - `layers.cond` statement instead of inserting them all into body of parent node. - Because private variables of class or other external scope will be modified. - For example, `self.var_dict["key"]`. In this case, nested structure of newly - defined functions is easier to understand. - """ - if not self.new_func_nodes: - return - idx = -1 - if isinstance(node, list): - idx = len(node) - 1 - elif isinstance(node, gast.AST): - for _, child in gast.iter_fields(node): - self._insert_func_nodes(child) - while idx >= 0: - child_node = node[idx] - if child_node in self.new_func_nodes: - node[idx:idx] = self.new_func_nodes[child_node] - idx = idx + len(self.new_func_nodes[child_node]) - 1 - del self.new_func_nodes[child_node] - else: - self._insert_func_nodes(child_node) - idx = idx - 1 - - def get_new_func_nodes(self): - return self.new_func_nodes - def merge_multi_assign_nodes(assign_nodes): """ @@ -467,10 +428,6 @@ class NameVisitor(gast.NodeVisitor): else: self.name_ids = before_name_ids - def visit_Return(self, node): - # Ignore the vars in return - return - def _visit_child(self, node): self.name_ids = defaultdict(list) if isinstance(node, list): @@ -553,25 +510,63 @@ def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load): return arguments -def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict): +def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, + after_ifelse_vars_dict): """ Find out the ast.Name list of output by analyzing node's AST information. - Following conditions should be satisfied while determining whether a variable is a return value: - 1. the var in parent scope is modified in if/else node. - 2. new var is both created in if and else node. + One of the following conditions should be satisfied while determining whether a variable is a return value: + 1. the var in parent scope is modified in If.body or If.orelse node. + 2. new var is both created in If.body and If.orelse node. + 3. new var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node. - If different var is modified in if and else node, it should add the var in return_ids - of different node. For example: - x, y = 5, 10 - if x > 4: - x = x+1 - z = x*x - else: - y = y - 1 - z = y*y + x, y = 5, 10 + if x > 4: + x = x+1 + z = x*x + q = 10 + else: + y = y - 1 + z = y*y + m = 20 + n = 20 + + print(q) + n = 30 + print(n) + + + The return_ids are (x, y, z, q) for `If.body` and `If.orelse`node, because + 1. x is modified in If.body node, + 2. y is modified in If.body node, + 3. z is both created in If.body and If.orelse node, + 4. q is created only in If.body, and it is used by `print(q)` as gast.Load. + Note: + After transformed, q and z are created in parent scope. For example, + + x, y = 5, 10 + q = fluid.dygraph.dygraph_to_static.variable_trans_func.data_layer_not_check(name='q', shape=[-1], dtype='float32') + z = fluid.dygraph.dygraph_to_static.variable_trans_func.data_layer_not_check(name='z', shape=[-1], dtype='float32') + + def true_func(x, y, q): + x = x+1 + z = x*x + q = 10 + return x,y,z,q + + def false_func(x, y, q): + y = y - 1 + z = y*y + m = 20 + n = 20 + return x,y,z,q + + x,y,z,q = fluid.layers.cond(x>4, lambda: true_func(x, y), lambda: false_func(x, y, q)) + + m and n are not in return_ids, because + 5. m is created only in If.orelse, but it is not used after gast.If node. + 6. n is created only in If.orelse, and it is used by `n = 30` and `print(n)`, but it is not used as gast.Load firstly but gast.Store . - The return_ids should be (x, y, z) for `if` and `else`node. """ def _is_return_var(ctxs): @@ -587,57 +582,112 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict): vars.append(k) return vars - def _candidate_vars(child_dict, parent_dict): + def _modified_vars(child_dict, parent_dict): return set([ var for var in _vars_with_store(child_dict) if var in parent_dict ]) - # 1. the var in parent_ids is modified in if/else node. - if_candidate_vars = _candidate_vars(if_vars_dict, parent_vars_dict) - else_candidate_vars = _candidate_vars(else_vars_dict, parent_vars_dict) - - # 2. new var is both created in if and else node. - if_new_vars = set([ + def _vars_loaded_before_store(ids_dict): + new_dict = defaultdict(list) + for k, ctxs in ids_dict.items(): + for ctx in ctxs: + if isinstance(ctx, gast.Load): + new_dict[k].append(ctx) + elif isinstance(ctx, gast.Store): + break + return new_dict + + # modified vars + body_modified_vars = _modified_vars(if_vars_dict, parent_vars_dict) + orelse_modified_vars = _modified_vars(else_vars_dict, parent_vars_dict) + modified_vars = body_modified_vars | orelse_modified_vars + + # new vars + body_new_vars = set([ var for var in _vars_with_store(if_vars_dict) if var not in parent_vars_dict ]) - else_new_vars = set([ + orelse_new_vars = set([ var for var in _vars_with_store(else_vars_dict) if var not in parent_vars_dict ]) - new_vars = if_new_vars & else_new_vars + new_vars_in_body_or_orelse = body_new_vars | orelse_new_vars + new_vars_in_one_of_body_or_orelse = body_new_vars ^ orelse_new_vars + + # 1. the var in parent scope is modified in If.body or If.orelse node. + modified_vars_from_parent = modified_vars - new_vars_in_body_or_orelse + + # 2. new var is both created in If.body and If.orelse node. + new_vars_in_body_and_orelse = body_new_vars & orelse_new_vars - # generate return_ids of if/else node. - modified_vars = if_candidate_vars | else_candidate_vars - return_ids = list(modified_vars | new_vars) + # 3. new var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node. + used_vars_after_ifelse = set( + [var for var in _vars_loaded_before_store(after_ifelse_vars_dict)]) + new_vars_to_create = new_vars_in_one_of_body_or_orelse & used_vars_after_ifelse | new_vars_in_body_and_orelse + + # 4. generate return_ids of if/else node. + return_ids = list(modified_vars_from_parent | new_vars_in_body_and_orelse | + new_vars_to_create) return_ids.sort() - return return_ids, list(modified_vars - new_vars) + return return_ids, modified_vars_from_parent, new_vars_to_create def transform_if_else(node, root): """ Transform ast.If into control flow statement of Paddle static graph. """ + # TODO(liym27): Consider variable like `self.a` modified in if/else node. parent_name_ids = get_name_ids([root], end_node=node) - if_name_ids = get_name_ids(node.body) - else_name_ids = get_name_ids(node.orelse) - - return_name_ids, modified_name_ids = parse_cond_return( - parent_name_ids, if_name_ids, else_name_ids) + body_name_ids = get_name_ids(node.body) + orelse_name_ids = get_name_ids(node.orelse) + + # Get after_ifelse_name_ids, which means used var names after If.body and If.orelse node. + after_ifelse_name_ids = defaultdict(list) + all_name_ids = get_name_ids([root]) + for name in all_name_ids: + before_var_names_ids = parent_name_ids.get(name, []) + \ + body_name_ids.get(name, []) + orelse_name_ids.get(name, []) + # Note: context of node.Name like gast.Load is a concrete object which has unique id different from other gast.Load + # E.g. ctx of `x` can be [, , ] + after_var_names_ids = [ + ctx for ctx in all_name_ids[name] if ctx not in before_var_names_ids + ] + if after_var_names_ids: + after_ifelse_name_ids[name] = after_var_names_ids + + return_name_ids, modified_name_ids_from_parent, new_vars_to_create = parse_cond_return( + parent_name_ids, body_name_ids, orelse_name_ids, after_ifelse_name_ids) + + # NOTE: Python can create variable only in if body or only in else body, and use it out of if/else. + # E.g. + # + # if x > 5: + # a = 10 + # print(a) + # + # Create static variable for those variables + create_new_vars_in_parent_stmts = [] + for name in new_vars_to_create: + # NOTE: Consider variable like `self.a` modified in if/else node. + if "." not in name: + create_new_vars_in_parent_stmts.append( + create_static_variable_gast_node(name)) + + modified_name_ids = modified_name_ids_from_parent | new_vars_to_create true_func_node = create_funcDef_node( node.body, name=unique_name.generate(TRUE_FUNC_PREFIX), - input_args=parse_cond_args(if_name_ids, modified_name_ids), + input_args=parse_cond_args(body_name_ids, modified_name_ids), return_name_ids=return_name_ids) false_func_node = create_funcDef_node( node.orelse, name=unique_name.generate(FALSE_FUNC_PREFIX), - input_args=parse_cond_args(else_name_ids, modified_name_ids), + input_args=parse_cond_args(orelse_name_ids, modified_name_ids), return_name_ids=return_name_ids) - return true_func_node, false_func_node, return_name_ids + return create_new_vars_in_parent_stmts, true_func_node, false_func_node, return_name_ids def create_cond_node(return_name_ids, diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index b2dbd6cc5971b6ae86d2febed44adc65bc4f3b73..f024b8c3dcc9c41271488e28c5a40143749c8b43 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -52,6 +52,51 @@ def dyfunc_with_if_else2(x, col=100): return y +def dyfunc_with_if_else3(x): + # Create new var in parent scope, return it in true_fn and false_fn. + # The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node. + # The transformed code: + """ + q = fluid.dygraph.dygraph_to_static.variable_trans_func. + data_layer_not_check(name='q', shape=[-1], dtype='float32') + z = fluid.dygraph.dygraph_to_static.variable_trans_func. + data_layer_not_check(name='z', shape=[-1], dtype='float32') + + def true_fn_0(q, x, y): + x = x + 1 + z = x + 2 + q = x + 3 + return q, x, y, z + + def false_fn_0(q, x, y): + y = y + 1 + z = x - 2 + m = x + 2 + n = x + 3 + return q, x, y, z + q, x, y, z = fluid.layers.cond(fluid.layers.mean(x)[0] < 5, lambda : + fluid.dygraph.dygraph_to_static.convert_call(true_fn_0)(q, x, y), + lambda : fluid.dygraph.dygraph_to_static.convert_call(false_fn_0)(q, + x, y)) + """ + y = x + 1 + # NOTE: x_v[0] < 5 is True + if fluid.layers.mean(x).numpy()[0] < 5: + x = x + 1 + z = x + 2 + q = x + 3 + else: + y = y + 1 + z = x - 2 + m = x + 2 + n = x + 3 + + q = q + 1 + n = q + 2 + x = n + return x + + def nested_if_else(x_v): batch_size = 16 feat_size = x_v.shape[-1] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 0363e32e030d76c1fc4068d86c28c362a0085773..5656c7fce81e3957b3d0318a2edd27483494ba6a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -64,18 +64,24 @@ class TestDygraphIfElse2(TestDygraphIfElse): class TestDygraphIfElse3(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = dyfunc_with_if_else3 + + +class TestDygraphNestedIfElse(TestDygraphIfElse): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = nested_if_else -class TestDygraphIfElse4(TestDygraphIfElse): +class TestDygraphNestedIfElse2(TestDygraphIfElse): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = nested_if_else_2 -class TestDygraphIfElse5(TestDygraphIfElse): +class TestDygraphNestedIfElse3(TestDygraphIfElse): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = nested_if_else_3 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py index 1efa4961b7437b3d262190be74ed7a0c06ed98d3..e5dc3c219dff54658af7bd81aec2aadcc3b50e9d 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py @@ -34,7 +34,7 @@ class TestGetNameIds(unittest.TestCase): def test_fn(x): return x+1 """ - self.all_name_ids = {'x': [gast.Param()]} + self.all_name_ids = {'x': [gast.Param(), gast.Load()]} def test_get_name_ids(self): source = textwrap.dedent(self.source) @@ -82,6 +82,7 @@ class TestGetNameIds2(TestGetNameIds): gast.Load(), gast.Store(), gast.Store(), + gast.Load(), ] } @@ -113,6 +114,7 @@ class TestGetNameIds3(TestGetNameIds): gast.Store(), gast.Load(), gast.Store(), + gast.Load(), ] }