From d1f9a26c95d329e6137694eb80fcce256f198470 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 10 Apr 2020 18:46:51 +0800 Subject: [PATCH] Refine mechanism of calling outer function in dy2static (#23688) * Refine mechanism of calling outer function test=develop * fix typo test=develop --- .../fluid/dygraph/dygraph_to_static/utils.py | 65 ++++++------------- .../unittests/dygraph_to_static/test_dict.py | 12 ++-- .../dygraph_to_static/test_ifelse.py | 16 ++++- 3 files changed, 41 insertions(+), 52 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index c052a3525f..8e89c78e61 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -308,42 +308,6 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids): return func_def_node -class ImportVisitor(gast.NodeVisitor): - """ - Visitor to parse all `import` statement. - """ - - def __init__(self, file_name): - self.root = self.file_to_ast(file_name) - self.import_statements = [] - - def transform(self): - if self.root is not None: - self.visit(self.root) - self.after_visit() - return self.import_statements - - def visit_Import(self, node): - self.import_statements.append(ast_to_source_code(node)) - return node - - def visit_ImportFrom(self, node): - self.import_statements.append(ast_to_source_code(node)) - return node - - def after_visit(self): - essential_statements = ["import paddle.fluid as fluid\n"] - new_stmts = set(essential_statements) - set(self.import_statements) - self.import_statements.extend(list(new_stmts)) - - def file_to_ast(self, file_name): - root = None - if file_name is not None: - with open(file_name) as f: - root = gast.parse(f.read()) - return root - - def index_in_list(array_list, item): try: return array_list.index(item) @@ -392,6 +356,8 @@ class RenameTransformer(gast.NodeTransformer): def ast_to_func(ast_root, dyfunc, delete_on_exit=True): """ Transform modified AST of decorated function into python callable object. + TODO: If only decorate one of inner function instead of decorating the main + function, the other inner functions are invisible for the decorated function. """ source = ast_to_source_code(ast_root) if six.PY2: @@ -400,16 +366,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): else: f = tempfile.NamedTemporaryFile( mode='w', suffix='.py', delete=False, encoding='utf-8') - # `sys.modules` is used to cache all modules and packages that avoids - # to import same modules twice by the import mechanism in python. - # We insert the import statements defined in source file into the tmpfile - # to make it easier to import external functions correctly. - source_file = inspect.getfile(dyfunc) - import_statements = ImportVisitor(source_file).transform() - import_str = "".join(import_statements) with f: module_name = os.path.basename(f.name[:-3]) - f.write(import_str) f.write(source) if delete_on_exit: @@ -420,8 +378,25 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): raise ValueError( 'Function: %s doesn\'t exist in the Module transformed from AST.' % func_name) + callable_func = getattr(module, func_name) + # After transform dygraph function into callable_func saved in tmp file, + # it lost the global variables from imported statements or defined in source file. + # Recovers the necessary variables by `__globals__`. + recover_globals_attribute(dyfunc, callable_func) + + return callable_func, f.name + + +def recover_globals_attribute(src_obj, dst_obj): + attr_name = '__globals__' + + src_globals = getattr(src_obj, attr_name, {}) + dst_globals = getattr(dst_obj, attr_name, {}) - return getattr(module, func_name), f.name + for k, v in src_globals.items(): + # ignore builtin attribute. + if not (k.startswith('__') and k.endswith('__')): + dst_globals[k] = v def ast_to_source_code(ast_node): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py index 959fea7e1a..0fb09eaa4c 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py @@ -99,14 +99,16 @@ class MainNetWithDict(fluid.dygraph.Layer): out = input for i in range(max_len): out = self.sub_net(out, cache) - cache = self.update_cache(cache) + cache = update_cache(cache) return out - def update_cache(self, cache): - for k, val in six.iteritems(cache): - cache[k] = fluid.layers.softmax(val) - return cache +# Test to call function defined outside of class. +def update_cache(cache): + for k, val in six.iteritems(cache): + cache[k] = fluid.layers.softmax(val) + + return cache class TestNetWithDict(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index d7122629f3..abac936f65 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -209,12 +209,18 @@ class TestDygraphIfElseNet(unittest.TestCase): self.assertTrue((self._run_dygraph() == self._run_static()).all()) +# Test to call function ahead caller. +def relu(x): + return fluid.layers.relu(x) + + def call_external_func(x, label=None): - if fluid.layers.mean(x).numpy()[0] > 5: + if fluid.layers.mean(x) < 0: x_v = x - 1 else: x_v = add_fn(x) + x_v = relu(x_v) if label is not None: loss = loss_fn(x_v, label) return loss @@ -230,17 +236,23 @@ class TestAst2FuncWithExternalFunc(TestDygraphIfElse): class NetWithExternalFunc(fluid.dygraph.Layer): @dygraph_to_static_func def forward(self, x, label=None): - if fluid.layers.mean(x).numpy()[0] > 5: + if fluid.layers.mean(x) < 0: x_v = x - 1 else: x_v = add_fn(x) + x_v = softmax(x_v) if label is not None: loss = loss_fn(x_v, label) return loss return x_v +# Test to call function behind caller. +def softmax(x): + return fluid.layers.softmax(x) + + class TestNetWithExternalFunc(TestDygraphIfElseNet): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') -- GitLab