From 036121b731c9da8e0a0a75b433593b0cd43a88d1 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Sat, 4 Apr 2020 20:35:36 +0800 Subject: [PATCH] [Dy2stat] Make loop_transformer supports class variable (#23478) This CR makes two changes: 1. In old loop_transformer, if a class variable, such as "self.a, foo.bar" is a loop var, the Dy2stat will fail because `def func(self.foo)` is not legal syntax. We support class variable by renaming. 2. After https://github.com/PaddlePaddle/Paddle/pull/22892 is merged, we can support `while x < 10` in dygraph. I enable those tests in corresponding Dy2stat --- .../dygraph_to_static/loop_transformer.py | 28 +++++++++- .../fluid/dygraph/dygraph_to_static/utils.py | 46 +++++++++++++---- .../dygraph_to_static/ifelse_simple_func.py | 17 +++++++ .../dygraph_to_static/test_break_continue.py | 37 +++++++++----- .../dygraph_to_static/test_ifelse.py | 16 ++++-- .../unittests/dygraph_to_static/test_loop.py | 51 +++++++++++++++---- 6 files changed, 157 insertions(+), 38 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index a638c66af03..82400ab01b8 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -26,6 +26,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name +from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node @@ -36,6 +37,7 @@ WHILE_BODY_PREFIX = 'while_body' FOR_CONDITION_PREFIX = 'for_loop_condition' FOR_BODY_PREFIX = 'for_loop_body' +GENERATE_VARIABLE_PREFIX = 'generate_variable' def create_while_node(condition_name, body_name, loop_var_names): @@ -440,7 +442,8 @@ class LoopTransformer(gast.NodeTransformer): # # We need to create static variable for those variables for name in create_var_names: - new_stmts.append(create_static_variable_gast_node(name)) + if "." not in name: + new_stmts.append(create_static_variable_gast_node(name)) new_stmts.append(init_stmt) @@ -468,6 +471,11 @@ class LoopTransformer(gast.NodeTransformer): decorator_list=[], returns=None, type_comment=None) + for name in loop_var_names: + if "." in name: + rename_transformer = RenameTransformer(condition_func_node) + rename_transformer.rename( + name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) new_stmts.append(condition_func_node) new_body = node.body @@ -495,6 +503,11 @@ class LoopTransformer(gast.NodeTransformer): decorator_list=[], returns=None, type_comment=None) + for name in loop_var_names: + if "." in name: + rename_transformer = RenameTransformer(body_func_node) + rename_transformer.rename( + name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) new_stmts.append(body_func_node) while_loop_node = create_while_node(condition_func_node.name, @@ -521,7 +534,8 @@ class LoopTransformer(gast.NodeTransformer): # # We need to create static variable for those variables for name in create_var_names: - new_stmts.append(create_static_variable_gast_node(name)) + if "." not in name: + new_stmts.append(create_static_variable_gast_node(name)) # while x < 10 in dygraph should be convert into static tensor < 10 for name in loop_var_names: @@ -550,6 +564,11 @@ class LoopTransformer(gast.NodeTransformer): decorator_list=[], returns=None, type_comment=None) + for name in loop_var_names: + if "." in name: + rename_transformer = RenameTransformer(condition_func_node) + rename_transformer.rename( + name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) new_stmts.append(condition_func_node) new_body = node.body @@ -576,6 +595,11 @@ class LoopTransformer(gast.NodeTransformer): decorator_list=[], returns=None, type_comment=None) + for name in loop_var_names: + if "." in name: + rename_transformer = RenameTransformer(body_func_node) + rename_transformer.rename( + name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) new_stmts.append(body_func_node) while_loop_node = create_while_node(condition_func_node.name, diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 581e727a5b8..c052a3525fc 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -352,6 +352,43 @@ def index_in_list(array_list, item): return -1 +def create_assign_node(name, node): + """ + Creates a `gast.Assign` node by given name_id as target and node as value. + """ + targets = generate_name_node(name, ctx=gast.Store()) + assign_node = gast.Assign(targets=[targets], value=node) + return targets, assign_node + + +class RenameTransformer(gast.NodeTransformer): + def __init__(self, node): + assert isinstance( + node, gast.AST), "RenameTransformer only accepts gast.AST as input" + self.root = node + self.old_name = "" + self.new_name = "" + + def rename(self, old_name, new_name): + self.old_name = old_name + self.new_name = new_name + self.visit(self.root) + + def visit_Name(self, node): + self.generic_visit(node) + if node.id == self.old_name: + node.id = self.new_name + return node + + def visit_Attribute(self, node): + self.generic_visit(node) + attr_full_name = get_attribute_full_name(node) + if attr_full_name == self.old_name: + new_name_node = gast.parse(self.new_name).body[0].value + return new_name_node + return node + + def ast_to_func(ast_root, dyfunc, delete_on_exit=True): """ Transform modified AST of decorated function into python callable object. @@ -399,12 +436,3 @@ def ast_to_source_code(ast_node): ast_node = gast.gast_to_ast(ast_node) source_code = astor.to_source(ast_node) return source_code - - -def create_assign_node(name, node): - """ - Creates a `gast.Assign` node by given name_id as target and node as value. - """ - targets = generate_name_node(name, ctx=gast.Store()) - assign_node = gast.Assign(targets=[targets], value=node) - return targets, assign_node diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index 73f4e803fc2..99f5e81fd9c 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -233,3 +233,20 @@ def if_with_and_or_4(x, y=None): mean_res.numpy()[0] > 0): x = x - 1 return x + + +def if_with_class_var(x, y=None): + class Foo(object): + def __init__(self): + self.a = 1 + self.b = 2 + + foo = Foo() + batch_size = fluid.layers.shape(x) + mean_res = fluid.layers.mean(x) + + if batch_size[0] > foo.a: + x = x + foo.b + else: + x = x - foo.b + return x diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py index 7c5d10abf4e..5c91f08f328 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py @@ -124,6 +124,26 @@ def test_for_in_else(x): return x +def while_loop_class_var(x): + class Foo(object): + def __init__(self): + self.a = 3 + self.b = 4 + self.c = 5 + + foo = Foo() + i = fluid.dygraph.to_variable(x) + while i < 10: + foo.b = fluid.layers.zeros(shape=[1], dtype='float32') + foo.c = foo.b + foo.a + i += 1 + if foo.c < 0: + continue + if foo.c > 6: + break + return foo.c + + class TestContinueInFor(unittest.TestCase): def setUp(self): self.input = np.zeros((1)).astype('int32') @@ -186,24 +206,15 @@ class TestContinueInWhile(TestContinueInFor): def init_dygraph_func(self): self.dygraph_func = test_continue_in_while - def test_transformed_static_result(self): - # TODO: while i < 10 in dygraph will be supported after PR22892 - # so currently we just assert static result. - # remove this overrided function after PR22892 is merged - static_res = self.run_static_mode() - self.assertEqual(15, static_res[0]) - class TestBreakInWhile(TestContinueInWhile): def init_dygraph_func(self): self.dygraph_func = test_break_in_while - def test_transformed_static_result(self): - # TODO: while i < 10 in dygraph will be supported after PR22892 - # so currently we just assert static result. - # remove this overrided function after PR22892 is merged - static_res = self.run_static_mode() - self.assertEqual(15, static_res[0]) + +class TestWhileLoopClassVar(TestContinueInWhile): + def init_dygraph_func(self): + self.dygraph_func = while_loop_class_var if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 1bd1fb22631..d7122629f39 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -116,7 +116,7 @@ class TestDygraphIfElse6(TestDygraphIfElse): self.dyfunc = dyfunc_ifExp_with_while -def dyfunc_ifExp_with_while2(x): +def dyfunc_ifExp(x): y = [x] def add_fn(x): @@ -128,16 +128,16 @@ def dyfunc_ifExp_with_while2(x): i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) # It will be converted into `layers.cond` as followed. - # map_func(lambda x: fluid.layers.cond(i==0, lambda: x, lambda: add_fn(x), y) - # `i (Tensor) == 0` is supported in dygraph. - y = map_func(lambda x: x if i == 0 else add_fn(x), y) + # map_func(lambda x: fluid.layers.cond(i==1, lambda: x, lambda: add_fn(x), y) + # `if (Tensor) == 1` is supported in dygraph. + y = map_func(lambda x: x if i == 1 else add_fn(x), y) return y[0] class TestDygraphIfElse7(TestDygraphIfElse): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') - self.dyfunc = dyfunc_ifExp_with_while2 + self.dyfunc = dyfunc_ifExp class TestDygraphIfElseWithAndOr(TestDygraphIfElse): @@ -170,6 +170,12 @@ class TestDygraphIfElseWithAndOr4(TestDygraphIfElse): self.dyfunc = if_with_and_or_4 +class TestDygraphIfElseWithClassVar(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = if_with_class_var + + class TestDygraphIfElseNet(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 9b673bdcd1b..b64fa34500f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -78,6 +78,36 @@ def while_loop_bool_op(x): return i +def while_loop_class_var(x): + class Foo(object): + def __init__(self): + self.a = 3 + self.b = 4 + self.c = 5 + + foo = Foo() + i = fluid.dygraph.to_variable(x) + while i < 10: + foo.b = fluid.layers.zeros(shape=[1], dtype='float32') + foo.c = foo.b + foo.a + i += 1 + return foo.c + + +def for_loop_class_var(max_len): + class Foo(object): + def __init__(self): + self.a = 3 + self.b = 4 + self.c = 5 + + foo = Foo() + for i in range(max_len): + foo.b = fluid.layers.zeros(shape=[1], dtype='float32') + foo.c = foo.b + foo.a + return foo.c + + def var_create_in_for_loop(max_len): for i in range(max_len): ret = fluid.layers.zeros(shape=[3, 4, 5], dtype='float64') @@ -136,15 +166,8 @@ class TestTransformWhileLoop(unittest.TestCase): def test_ast_to_func(self): static_numpy = self._run_static() - self.assertTrue( - np.allclose( - np.full( - shape=(1), fill_value=45, dtype=np.int32), static_numpy)) - - # Enable next lines after Paddle dygraph supports while x < 10 - # - # self._run_dygraph() - # self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) + dygraph_numpy = self._run_dygraph() + self.assertTrue(np.allclose(dygraph_numpy, static_numpy)) class TestTransformWhileLoopWithConflicVar(TestTransformWhileLoop): @@ -162,6 +185,11 @@ class TestWhileLoopBoolOp(TestTransformWhileLoop): self.dyfunc = while_loop_bool_op +class TestWhileLoopClassVar(TestTransformWhileLoop): + def _init_dyfunc(self): + self.dyfunc = while_loop_class_var + + class TestTransformForLoop(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( @@ -192,6 +220,11 @@ class TestTransformForLoop(unittest.TestCase): self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) +class TestClassVarInForLoop(TestTransformForLoop): + def _init_dyfunc(self): + self.dyfunc = for_loop_class_var + + class TestVarCreateInForLoop(TestTransformForLoop): def _init_dyfunc(self): self.dyfunc = var_create_in_for_loop -- GitLab