From d7a7c5f0bf43bdfcc2cf8997d9fd516bf24c7681 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Wed, 11 Mar 2020 18:48:40 +0800 Subject: [PATCH] Support Simple For Range Loop in Dygraph to Static (#22867) 1. Add basic support for `for in range` loop 2. Move `test_dygraph_to_static_*` to `dygraph_to_static` dir and rename them 3. Add test case for dict in while_loop --- .../dygraph_to_static/loop_transformer.py | 237 ++++++++++++++++-- .../fluid/dygraph/dygraph_to_static/utils.py | 12 + .../dygraph_to_static/variable_trans_func.py | 12 +- .../unittests/dygraph_to_static/test_dict.py | 24 +- .../unittests/dygraph_to_static/test_loop.py | 68 ++++- .../tests/unittests/test_while_loop_op.py | 28 +++ 6 files changed, 343 insertions(+), 38 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 10f170aaefc..d92988e7866 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -21,6 +21,8 @@ from collections import defaultdict from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node +from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node @@ -29,6 +31,9 @@ __all__ = ['LoopTransformer', 'NameVisitor'] WHILE_CONDITION_PREFIX = 'while_condition' WHILE_BODY_PREFIX = 'while_body' +FOR_CONDITION_PREFIX = 'for_loop_condition' +FOR_BODY_PREFIX = 'for_loop_body' + def create_while_node(condition_name, body_name, loop_var_names): while_args = [] @@ -63,13 +68,16 @@ class NameVisitor(gast.NodeVisitor): ''' def __init__(self, root_node): - # Set of gast.Name + # Set of gast.Name or gast.Attribute for variables self.current_seen_vars = set() + # list of nodes of current visit node + self.ancestor_nodes = [] + # List of gast.While/gast.For nodes self.current_loop = [] - # Mapping from gast.While/gast.For to string name of vars - self.before_loop_vars = defaultdict(set) + # Mapping from gast.While/gast.For to variable nodes + self.before_loop_body_vars = defaultdict(set) self.in_loop_vars = defaultdict(set) self.visit(root_node) @@ -86,13 +94,12 @@ class NameVisitor(gast.NodeVisitor): read_context = {type(gast.Load()), type(gast.AugLoad())} in_loop_vars = self.in_loop_vars[node] - in_loop_name_strs = set(name.id for name in in_loop_vars) - before_loop_vars = self.before_loop_vars[node] - before_loop_name_strs = set(name.id for name in before_loop_vars) - after_loop_vars = self.current_seen_vars - before_loop_vars - in_loop_vars - after_loop_name_strs = set( - name.id for name in after_loop_vars - if type(name.ctx) in read_context) + in_loop_name_strs = self._var_nodes_to_names(in_loop_vars) + before_loop_body_vars = self.before_loop_body_vars[node] + before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars) + after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars + after_loop_name_strs = self._var_nodes_to_names(after_loop_vars, + read_context) for name in in_loop_name_strs: if name in before_loop_name_strs: # If a variable is used in loop and created before loop, it @@ -106,23 +113,65 @@ class NameVisitor(gast.NodeVisitor): return loop_var_names, create_var_names def visit_Name(self, node): + if self._is_call_func_name_node(node): + self.generic_visit(node) + return + self.current_seen_vars.add(node) for loop_node in self.current_loop: self.in_loop_vars[loop_node].add(node) self.generic_visit(node) + def visit(self, node): + self.ancestor_nodes.append(node) + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + ret = visitor(node) + self.ancestor_nodes.pop() + return ret + + def visit_Attribute(self, node): + if self._is_call_func_name_node(node): + return + + attr_full_name = get_attribute_full_name(node) + self.current_seen_vars.add(node) + for loop_node in self.current_loop: + self.in_loop_vars[loop_node].add(node) + # sub-nodes are visited during get_attribute_full_name and we shouldn't + # visit again + def visit_For(self, node): self.current_loop.append(node) - self.before_loop_vars[node] = copy.copy(self.current_seen_vars) + self.visit(node.target) + self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars) self.generic_visit(node) self.current_loop.pop() def visit_While(self, node): self.current_loop.append(node) - self.before_loop_vars[node] = copy.copy(self.current_seen_vars) + self.visit(node.test) + self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars) self.generic_visit(node) self.current_loop.pop() + def _var_nodes_to_names(self, node_set, ctx_filter_set=None): + ret = set() + for node in node_set: + if ctx_filter_set is None or type(node.ctx) in ctx_filter_set: + if isinstance(node, gast.Name): + ret.add(node.id) + elif isinstance(node, gast.Attribute): + ret.add(get_attribute_full_name(node)) + return ret + + def _is_call_func_name_node(self, node): + if self.ancestor_nodes: + parent_node = self.ancestor_nodes[-1] + if isinstance(parent_node, gast.Call) and parent_node.func == node: + return True + return False + class LoopTransformer(gast.NodeTransformer): """ @@ -140,11 +189,6 @@ class LoopTransformer(gast.NodeTransformer): def transform(self): self.visit(self.root) - def get_for_stmt_nodes(self, node): - self.generic_visit(node) - # TODO - return node - def visit(self, node): self.generic_visit(node) # All parent nodes that may contain gast.While/gast.For @@ -165,15 +209,166 @@ class LoopTransformer(gast.NodeTransformer): body_list[i:i + 1] = new_stmts i += len(new_stmts) elif isinstance(body_list[i], gast.For): - # TODO - i += 1 + new_stmts = self.get_for_stmt_nodes(body_list[i]) + body_list[i:i + 1] = new_stmts + i += len(new_stmts) else: i += 1 + def get_for_range_node(self, node): + if not isinstance(node.iter, gast.Call): + return None + if not isinstance(node.iter.func, gast.Name): + return None + if node.iter.func.id != "range": + return None + return node.iter + + def get_for_args_stmts(self, iter_name, args_list): + ''' + Returns 3 gast stmt nodes for argument. + 1. Initailize of iterate variable + 2. Condition for the loop + 3. Statement for changing of iterate variable during the loop + NOTE(TODO): Python allows to access iteration variable after loop, such + as "for i in range(10)" will create i = 9 after the loop. But using + current conversion will make i = 10. We should find a way to change it + ''' + len_range_args = len(args_list) + assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments" + if len_range_args == 1: + init_stmt = get_constant_variable_node(iter_name, 0) + else: + init_stmt = gast.Assign( + targets=[ + gast.Name( + id=iter_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=args_list[0]) + + range_max_node = args_list[0] if len_range_args == 1 else args_list[1] + step_node = args_list[2] if len_range_args == 3 else gast.Constant( + value=1, kind=None) + + cond_stmt = gast.Compare( + left=gast.BinOp( + left=gast.Name( + id=iter_name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + op=gast.Add(), + right=step_node), + ops=[gast.LtE()], + comparators=[range_max_node]) + + change_stmt = gast.AugAssign( + target=gast.Name( + id=iter_name, + ctx=gast.Store(), + annotation=None, + type_comment=None), + op=gast.Add(), + value=step_node) + + return init_stmt, cond_stmt, change_stmt + + def get_for_stmt_nodes(self, node): + # TODO: consider for - else in python + if not self.name_visitor.is_control_flow_loop(node): + return [node] + + # TODO: support non-range case + range_call_node = self.get_for_range_node(node) + if range_call_node is None: + return [node] + + if not isinstance(node.target, gast.Name): + return [node] + iter_var_name = node.target.id + + init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts( + iter_var_name, range_call_node.args) + + loop_var_names, create_var_names = self.name_visitor.get_loop_var_names( + node) + new_stmts = [] + # Python can create variable in loop and use it out of loop, E.g. + # + # for x in range(10): + # y += x + # print(x) # x = 10 + # + # We need to create static variable for those variables + for name in create_var_names: + new_stmts.append(create_static_variable_gast_node(name)) + + new_stmts.append(init_stmt) + + # for x in range(10) in dygraph should be convert into static tensor + 1 <= 10 + for name in loop_var_names: + new_stmts.append(to_static_variable_gast_node(name)) + + condition_func_node = gast.FunctionDef( + name=unique_name.generate(FOR_CONDITION_PREFIX), + args=gast.arguments( + args=[ + gast.Name( + id=name, + ctx=gast.Param(), + annotation=None, + type_comment=None) for name in loop_var_names + ], + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=None, + kwarg=None, + defaults=[]), + body=[gast.Return(value=cond_stmt)], + decorator_list=[], + returns=None, + type_comment=None) + new_stmts.append(condition_func_node) + + new_body = node.body + new_body.append(change_stmt) + new_body.append( + gast.Return(value=generate_name_node( + loop_var_names, ctx=gast.Load()))) + body_func_node = gast.FunctionDef( + name=unique_name.generate(FOR_BODY_PREFIX), + args=gast.arguments( + args=[ + gast.Name( + id=name, + ctx=gast.Param(), + annotation=None, + type_comment=None) for name in loop_var_names + ], + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=None, + kwarg=None, + defaults=[]), + body=new_body, + decorator_list=[], + returns=None, + type_comment=None) + new_stmts.append(body_func_node) + + while_loop_node = create_while_node(condition_func_node.name, + body_func_node.name, loop_var_names) + new_stmts.append(while_loop_node) + + return new_stmts + def get_while_stmt_nodes(self, node): # TODO: consider while - else in python - # self.generic_visit(node) - if not self.name_visitor.is_control_flow_loop(node): return [node] diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index fba46f16ee0..d4cf0143f57 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -215,6 +215,18 @@ def create_api_shape_node(tensor_shape_node): return api_shape_node +def get_constant_variable_node(name, value, shape=[1], dtype='int64'): + return gast.parse('%s = fluid.layers.fill_constant(%s, "%s", %s)' % + (name, str(shape), dtype, str(value))) + + +def get_attribute_full_name(node): + assert isinstance( + node, + gast.Attribute), "Input non-Attribute node to get attribute full name" + return astor.to_source(gast.gast_to_ast(node)).strip() + + def generate_name_node(name_ids, ctx=gast.Load()): """ Generate list or gast.Tuple of ast.Name for Return statement. diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py index 621299ddda2..c2979e86e77 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -14,6 +14,7 @@ from __future__ import print_function +import six import gast from paddle.fluid.layers import fill_constant @@ -39,8 +40,15 @@ def to_static_variable(x): ''' if isinstance(x, bool): return fill_constant(shape=[1], dtype='bool', value=x) - if isinstance(x, int): - return fill_constant(shape=[1], dtype='int64', value=x) if isinstance(x, float): return fill_constant(shape=[1], dtype='float64', value=x) + + if six.PY2: + if isinstance(x, int): + return fill_constant(shape=[1], dtype='int32', value=x) + if isinstance(x, long): + return fill_constant(shape=[1], dtype='int64', value=x) + else: + if isinstance(x, int): + return fill_constant(shape=[1], dtype='int64', value=x) return x diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py index 2d1a46d3ae5..7ffbabd75b0 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py @@ -59,7 +59,17 @@ class SubNetWithDict(fluid.dygraph.Layer): cache_k, cache_v = cache["k"], cache["v"] k = 0.1 * cache_k + k v = 0.2 * cache_v + v - cache["k"], cache["v"] = k, v + # TODO: currently while_loop can have a dict as loop_vars, but + # to change the value in a dict, you have to use layers.assign + # because cache["k"] = k is putting k in dict without building + # network. So we cannot write: + # + # cache["k"], cache["v"] = k, v + # + # we have to support this kind of dict in loop in the future. + # For example, automatically change = to assign in AutoTracer + fluid.layers.assign(k, cache["k"]) + fluid.layers.assign(v, cache["v"]) weight = fluid.layers.matmul(x=q, y=k, transpose_y=True) weight = fluid.layers.softmax(weight) @@ -94,12 +104,20 @@ class MainNetWithDict(fluid.dygraph.Layer): for i in range(max_len): out = self.sub_net(out, cache) cache = self.update_cache(cache) - return out def update_cache(self, cache): for k, val in six.iteritems(cache): - cache[k] = fluid.layers.softmax(val) + # TODO: currently while_loop can have a dict as loop_vars, but + # to change the value in a dict, you have to use layers.assign + # because cache["k"] = k is putting k in dict without building + # network. So we cannot write: + # + # cache[k] = fluid.layers.softmax(val) + # + # we have to support this kind of dict in loop in the future. + # For example, automatically change = to assign in AutoTracer + fluid.layers.assign(fluid.layers.softmax(val), cache[k]) return cache diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index c9594ff1714..3bb7a356288 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -35,20 +35,34 @@ def while_loop_dyfunc(x): return i +def for_loop_dyfunc(max_len): + for i in range(max_len): + ret = fluid.layers.zeros(shape=[1], dtype='float32') + fluid.layers.increment(ret, value=2.0, in_place=True) + return ret + + class TestNameVisitor(unittest.TestCase): + def setUp(self): + self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc] + self.loop_var_names = [set(["i", "x"]), set(["i", "ret", "max_len"])] + self.create_var_names = [set(), set(["ret"])] + def test_loop_vars(self): - test_func = inspect.getsource(while_loop_dyfunc) - gast_root = gast.parse(test_func) - name_visitor = NameVisitor(gast_root) - for node in gast.walk(gast_root): - if isinstance(node, gast.While): - loop_var_names, create_var_names = name_visitor.get_loop_var_names( - node) - self.assertEqual(loop_var_names, set(["i", "x"])) - self.assertEqual(create_var_names, set()) - - -class TestTransformWhile(unittest.TestCase): + for i in range(len(self.loop_funcs)): + func = self.loop_funcs[i] + test_func = inspect.getsource(func) + gast_root = gast.parse(test_func) + name_visitor = NameVisitor(gast_root) + for node in gast.walk(gast_root): + if isinstance(node, (gast.While, gast.For)): + loop_var_names, create_var_names = name_visitor.get_loop_var_names( + node) + self.assertEqual(loop_var_names, self.loop_var_names[i]) + self.assertEqual(create_var_names, self.create_var_names[i]) + + +class TestTransformWhileLoop(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( ) else fluid.CPUPlace() @@ -83,5 +97,35 @@ class TestTransformWhile(unittest.TestCase): # self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) +class TestTransformForLoop(unittest.TestCase): + def setUp(self): + self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() + self.len = 100 + + def _run_static(self): + main_program = fluid.Program() + with fluid.program_guard(main_program): + static_func = dygraph_to_static_graph(for_loop_dyfunc) + out = static_func(self.len) + exe = fluid.Executor(self.place) + ret = exe.run(main_program, fetch_list=out) + return ret + + def _run_dygraph(self): + with fluid.dygraph.guard(self.place): + ret = for_loop_dyfunc(self.len) + return ret.numpy() + + def test_ast_to_func(self): + static_numpy = self._run_static() + self.assertTrue( + np.allclose( + np.full( + shape=(1), fill_value=2, dtype=np.int32), static_numpy)) + self._run_dygraph() + self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_while_loop_op.py b/python/paddle/fluid/tests/unittests/test_while_loop_op.py index 6d86f604a1e..4c8e1217e3f 100644 --- a/python/paddle/fluid/tests/unittests/test_while_loop_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_loop_op.py @@ -77,6 +77,34 @@ class TestApiWhileLoop(unittest.TestCase): data = np.add(data, data_one) self.assertTrue(np.allclose(np.asarray(res[1]), data)) + def test_var_dict(self): + def cond(i, ten, test_dict): + return layers.less_than(i, ten) + + def body(i, ten, test_dict): + layers.assign(i, test_dict["test_key"]) + i = layers.increment(i) + return [i, ten, test_dict] + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + i = layers.zeros(shape=[1], dtype='int64') + ten = layers.fill_constant(shape=[1], dtype='int64', value=10) + test_data = layers.fill_constant(shape=[1], dtype='int64', value=0) + test_dict = {"test_key": test_data} + i, ten, test_dict = layers.while_loop(cond, body, + [i, ten, test_dict]) + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + res = exe.run(main_program, fetch_list=[test_dict["test_key"]]) + self.assertTrue( + np.allclose( + np.asarray(res[0]), + np.full( + shape=(1), fill_value=9, dtype=np.int64))) + class TestApiWhileLoop_Nested(unittest.TestCase): def test_nested_net(self): -- GitLab