diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index e2e7ed45a5c71d662737f1d2cf347de78e0227ec..86837e993d76ba3c3426f53219df200aba500524 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -42,7 +42,7 @@ class IfElseTransformer(gast.NodeTransformer): wrapper_root) self.wrapper_root = wrapper_root self.root = wrapper_root.node - self.new_func_nodes = [] + self.new_func_nodes = {} def ast_visit(self): """ @@ -59,10 +59,10 @@ class IfElseTransformer(gast.NodeTransformer): pred_node = node.test true_func_node, false_func_node, return_name_ids = transform_if_else( node, self.root) - self.new_func_nodes += [true_func_node, false_func_node] # create layers.cond new_node = create_cond_node(return_name_ids, pred_node, true_func_node, false_func_node) + self.new_func_nodes[new_node] = [true_func_node, false_func_node] return new_node else: return node @@ -82,10 +82,28 @@ class IfElseTransformer(gast.NodeTransformer): It can be used to add the created `true_fn/false_fn` in front of the node.body before they are called in cond layer. """ - assert hasattr(node, 'body') - # add new ast.funcDef of `if/else` - if self.new_func_nodes: - node.body = self.new_func_nodes + node.body + self._insert_func_nodes(node) + + def _insert_func_nodes(self, parent_node): + """ + Defined `true_func` and `false_func` will be inserted in front of corresponding + `layers.cond` statement instead of inserting them all into body of parent node. + Because private variables of class or other external scope will be modified. + For example, `self.var_dict["key"]`. In this case, nested structure of newly + defined functions is easier to understand. + """ + if not (self.new_func_nodes and hasattr(parent_node, 'body')): + return + idx = len(parent_node.body) - 1 + while idx >= 0: + child_node = parent_node.body[idx] + if child_node in self.new_func_nodes: + parent_node.body[idx:idx] = self.new_func_nodes[child_node] + idx = idx + len(self.new_func_nodes[child_node]) - 1 + del self.new_func_nodes[child_node] + else: + self._insert_func_nodes(child_node) + idx = idx - 1 def get_new_func_nodes(self): return self.new_func_nodes diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py index 94f891c5f6e6f25ea9219fd5cbd83bf4b7a89582..357f746eb5af0e237c1aa4fc3f85f1e772a019b6 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py @@ -104,7 +104,7 @@ def get_name_ids(nodes, not_name_set=None, node_black_list=None): name_ids = defaultdict(list) for node in nodes: - if node_black_list and node in node_black_list: continue + if node_black_list and node in node_black_list: break if isinstance(node, gast.AST): # In two case, the ast.Name should be filtered. # 1. Function name like `my_func` of my_func(x) diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py index 0ef58d10c62344ec1f68fc5970daab6a82855c93..0dabac93f920ffb15078a3926ee523562b1b317a 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py @@ -119,5 +119,91 @@ class TestDygraphIfElse3(TestDygraphIfElse): self.dyfunc = nested_if_else +class NetWithControlFlowIf(fluid.dygraph.Layer): + def __init__(self, hidden_dim=16): + super(NetWithControlFlowIf, self).__init__() + self.hidden_dim = hidden_dim + self.fc = fluid.dygraph.Linear( + input_dim=hidden_dim, + output_dim=5, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.99)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.5))) + self.alpha = 10. + self.constant_vars = {} + + @dygraph_to_static_graph + def forward(self, input): + hidden_dim = input.shape[-1] + # Plain `if` statement in Python + if hidden_dim != self.hidden_dim: + raise ValueError( + "hidden_dim {} of input is not equal to FC.weight[0]: {}" + .format(hidden_dim, self.hidden_dim)) + + self.constant_vars['bias'] = fluid.layers.fill_constant( + [5], dtype='float32', value=1) + # Control flow `if` statement + fc_out = self.fc(input) + if fluid.layers.mean(fc_out).numpy()[0] < 0: + y = fc_out + self.constant_vars['bias'] + self.constant_vars['w'] = fluid.layers.fill_constant( + [5], dtype='float32', value=10) + if y.numpy()[0] < self.alpha: + # Create new var, but is not used. + x = 10 + tmp = y * self.constant_vars['w'] + y = fluid.layers.relu(tmp) + # Nested `if/else` + if y.numpy()[-1] < self.alpha: + # Modify variable of class + self.constant_vars['w'] = fluid.layers.fill_constant( + [hidden_dim], dtype='float32', value=9) + y = fluid.layers.abs(y) + else: + tmp = fluid.layers.fill_constant( + [5], dtype='float32', value=-1) + y = y - tmp + else: + y = fc_out - self.constant_vars['bias'] + + loss = fluid.layers.mean(y) + return loss + + +class TestDygraphIfElseNet(unittest.TestCase): + """ + TestCase for the transformation from control flow `if/else` + dependent on tensor in Dygraph into Static `fluid.layers.cond`. + """ + + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.Net = NetWithControlFlowIf + + def _run_static(self): + main_program = fluid.Program() + with fluid.program_guard(main_program): + net = self.Net() + x_v = fluid.layers.assign(self.x) + # Transform into static graph + out = net(x_v) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + ret = exe.run(main_program, fetch_list=out) + return ret[0] + + def _run_dygraph(self): + with fluid.dygraph.guard(place): + net = self.Net() + x_v = fluid.dygraph.to_variable(self.x) + ret = net(x_v) + return ret.numpy() + + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + + if __name__ == '__main__': unittest.main()