diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 4bfb310a835e20555e265d32cd572b905aebe23d..79d24c05184713d2fff6005ab9bde25af0a27570 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -91,22 +91,27 @@ class IfElseTransformer(gast.NodeTransformer): class NameVisitor(gast.NodeVisitor): - def __init__(self, end_node=None): + def __init__(self, after_node=None, end_node=None): + # The start node (exclusive) of the visitor + self.after_node = after_node # The terminate node of the visitor. self.end_node = end_node # Dict to store the names and ctxs of vars. self.name_ids = defaultdict(list) # List of current visited nodes self.ancestor_nodes = [] - # Available only when end_node is set. - self._is_finished = False + # True when in range (after_node, end_node). + self._in_range = after_node is None self._candidate_ctxs = (gast.Store, gast.Load, gast.Param) self._def_func_names = set() def visit(self, node): """Visit a node.""" - if node == self.end_node or self._is_finished: - self._is_finished = True + if self.after_node is not None and node == self.after_node: + self._in_range = True + return + if node == self.end_node: + self._in_range = False return self.ancestor_nodes.append(node) @@ -137,18 +142,19 @@ class NameVisitor(gast.NodeVisitor): In above two cases, we should consider to manage the scope of vars to parsing the arguments and returned vars correctly. """ - if not self.end_node: + if not self._in_range or not self.end_node: self.generic_visit(node) + return else: before_if_name_ids = copy.deepcopy(self.name_ids) body_name_ids = self._visit_child(node.body) # If traversal process stops early in `if.body`, return the currently seen name_ids. - if self._is_finished: + if not self._in_range: self._update_name_ids(before_if_name_ids) else: else_name_ids = self._visit_child(node.orelse) # If traversal process stops early in `if.orelse`, return the currently seen name_ids. - if self._is_finished: + if not self._in_range: self._update_name_ids(before_if_name_ids) else: # Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch @@ -161,10 +167,13 @@ class NameVisitor(gast.NodeVisitor): self.name_ids = before_if_name_ids def visit_Attribute(self, node): - if not self._is_call_func_name_node(node): + if not self._in_range or not self._is_call_func_name_node(node): self.generic_visit(node) def visit_Name(self, node): + if not self._in_range: + self.generic_visit(node) + return blacklist = {'True', 'False', 'None'} if node.id in blacklist: return if node.id in self._def_func_names: @@ -174,11 +183,17 @@ class NameVisitor(gast.NodeVisitor): self.name_ids[node.id].append(node.ctx) def visit_Assign(self, node): + if not self._in_range: + self.generic_visit(node) + return # Visit `value` firstly. node._fields = ('value', 'targets') self.generic_visit(node) def visit_FunctionDef(self, node): + if not self._in_range: + self.generic_visit(node) + return self._def_func_names.add(node.name) if not self.end_node: self.generic_visit(node) @@ -187,7 +202,7 @@ class NameVisitor(gast.NodeVisitor): self.name_ids = defaultdict(list) self.generic_visit(node) - if self._is_finished: + if not self._in_range: self._update_name_ids(before_name_ids) else: self.name_ids = before_name_ids @@ -235,11 +250,13 @@ class NameVisitor(gast.NodeVisitor): self.name_ids[name_id] = ctxs + self.name_ids[name_id] -def get_name_ids(nodes, end_node=None): +def get_name_ids(nodes, after_node=None, end_node=None): """ - Return all ast.Name.id of python variable in nodes. + Return all ast.Name.id of python variable in nodes range from + (after_node, end_node) exclusively. If after_node or end_node is None, the + range is unlimited. """ - name_visitor = NameVisitor(end_node) + name_visitor = NameVisitor(after_node, end_node) for node in nodes: name_visitor.visit(node) return name_visitor.name_ids @@ -434,20 +451,8 @@ def transform_if_else(node, root): parent_name_ids = get_name_ids([root], end_node=node) body_name_ids = get_name_ids(node.body) orelse_name_ids = get_name_ids(node.orelse) - # Get after_ifelse_name_ids, which means used var names after If.body and If.orelse node. - after_ifelse_name_ids = defaultdict(list) - all_name_ids = get_name_ids([root]) - for name in all_name_ids: - before_var_names_ids = parent_name_ids.get(name, []) + \ - body_name_ids.get(name, []) + orelse_name_ids.get(name, []) - # Note: context of node.Name like gast.Load is a concrete object which has unique id different from other gast.Load - # E.g. ctx of `x` can be [, , ] - after_var_names_ids = [ - ctx for ctx in all_name_ids[name] if ctx not in before_var_names_ids - ] - if after_var_names_ids: - after_ifelse_name_ids[name] = after_var_names_ids + after_ifelse_name_ids = get_name_ids([root], after_node=node) return_name_ids, modified_name_ids_from_parent, new_vars_to_create = parse_cond_return( parent_name_ids, body_name_ids, orelse_name_ids, after_ifelse_name_ids) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index d8d4634ae508fac81722ade1cb9d0b9d6d453089..419150345b8f4c36854767640d01a93aba5f170e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -17,6 +17,7 @@ from __future__ import print_function import numpy as np import unittest +import paddle from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator @@ -271,5 +272,75 @@ class TestNetWithExternalFunc(TestDygraphIfElseNet): self.Net = NetWithExternalFunc +class DiffModeNet1(paddle.nn.Layer): + def __init__(self, mode): + super(DiffModeNet1, self).__init__() + self.mode = mode + + @paddle.jit.to_static + def forward(self, x, y): + if self.mode == 'train': + out = x + y + elif self.mode == 'infer': + out = x - y + else: + raise ValueError('Illegal mode') + return out + + +class DiffModeNet2(paddle.nn.Layer): + def __init__(self, mode): + super(DiffModeNet2, self).__init__() + self.mode = mode + + @paddle.jit.to_static + def forward(self, x, y): + if self.mode == 'train': + out = x + y + return out + elif self.mode == 'infer': + out = x - y + return out + else: + raise ValueError('Illegal mode') + + +class TestDiffModeNet(unittest.TestCase): + """ + TestCase for the net with different modes + """ + + def setUp(self): + self.x = paddle.randn([10, 16], 'float32') + self.y = paddle.randn([10, 16], 'float32') + self.init_net() + + def init_net(self): + self.Net = DiffModeNet1 + + def _run(self, mode, to_static): + prog_trans = ProgramTranslator() + prog_trans.enable(to_static) + + net = self.Net(mode) + ret = net(self.x, self.y) + return ret.numpy() + + def test_train_mode(self): + self.assertTrue((self._run( + mode='train', to_static=True) == self._run( + mode='train', to_static=False)).all()) + + def test_infer_mode(self): + self.assertTrue((self._run( + mode='infer', to_static=True) == self._run( + mode='infer', to_static=False)).all()) + + +class TestDiffModeNet2(TestDiffModeNet): + def init_net(self): + self.Net = DiffModeNet2 + + if __name__ == '__main__': unittest.main()