diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index eaa16575c6683112d9d3bcb8bb476b091cf3d566..004aa97bfaf7ca77b9a5cfaab365168a44fe069c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -245,6 +245,5 @@ def convert_to_static(dyfunc): root_wrapper = dygraph_to_static.get_static_ast(root) # Get static_func from AST - func_name = dygraph_to_static.get_module_name() - static_func, file_name = ast_to_func(root_wrapper.node, func_name) + static_func, file_name = ast_to_func(root_wrapper.node, dyfunc) return static_func, dygraph_to_static diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 598202918486b9d9c71ac2fba515e7558983ae20..66e45780e6ed9e4bc8e84ecf938a719ac9f8cb02 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -299,6 +299,42 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids): return func_def_node +class ImportVisitor(gast.NodeVisitor): + """ + Visitor to parse all `import` statement. + """ + + def __init__(self, file_name): + self.root = self.file_to_ast(file_name) + self.import_statements = [] + + def transform(self): + if self.root is not None: + self.visit(self.root) + self.after_visit() + return self.import_statements + + def visit_Import(self, node): + self.import_statements.append(ast_to_source_code(node)) + return node + + def visit_ImportFrom(self, node): + self.import_statements.append(ast_to_source_code(node)) + return node + + def after_visit(self): + essential_statements = ["import paddle.fluid as fluid\n"] + new_stmts = set(essential_statements) - set(self.import_statements) + self.import_statements.extend(list(new_stmts)) + + def file_to_ast(self, file_name): + root = None + if file_name is not None: + with open(file_name) as f: + root = gast.parse(f.read()) + return root + + def index_in_list(array_list, item): try: return array_list.index(item) @@ -307,7 +343,7 @@ def index_in_list(array_list, item): return -1 -def ast_to_func(ast_root, func_name, delete_on_exit=True): +def ast_to_func(ast_root, dyfunc, delete_on_exit=True): """ Transform modified AST of decorated function into python callable object. """ @@ -318,13 +354,13 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True): else: f = tempfile.NamedTemporaryFile( mode='w', suffix='.py', delete=False, encoding='utf-8') - - # TODO(Aurelius84): more elegant way to transform ast into callable object - import_str = "import paddle\n" \ - "import paddle.fluid as fluid\n" \ - "import paddle.fluid.layers as layers\n" \ - "import numpy as np\n" \ - "import numpy\n" + # `sys.modules` is used to cache all modules and packages that avoids + # to import same modules twice by the import mechanism in python. + # We insert the import statements defined in source file into the tmpfile + # to make it easier to import external functions correctly. + source_file = inspect.getfile(dyfunc) + import_statements = ImportVisitor(source_file).transform() + import_str = "".join(import_statements) with f: module_name = os.path.basename(f.name[:-3]) f.write(import_str) @@ -333,6 +369,7 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True): if delete_on_exit: atexit.register(lambda: os.remove(f.name)) module = imp.load_source(module_name, f.name) + func_name = dyfunc.__name__ if not hasattr(module, func_name): raise ValueError( 'Function: %s doesn\'t exist in the Module transformed from AST.' % diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index 8655bcab4471d6870f42e5304bbe4734060042c5..75ee2190bcdee92ad8ac82ef95bdfec9509859d5 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -18,6 +18,16 @@ import paddle.fluid as fluid from paddle.fluid.dygraph.jit import dygraph_to_static_graph +def add_fn(x): + x = x + 1 + return x + + +def loss_fn(x, lable): + loss = fluid.layers.cross_entropy(x, lable) + return loss + + def dyfunc_with_if_else(x_v, label=None): if fluid.layers.mean(x_v).numpy()[0] > 5: x_v = x_v - 1 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py index a048e20799db401afe31c6cc36084afb5618b9f7..62b6ac171a4c96089bade7b24ec8f2d02c9d94da 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py @@ -34,7 +34,7 @@ class TestAST2Func(unittest.TestCase): source = inspect.getsource(func) source = textwrap.dedent(source) ast_root = gast.parse(source) - transformed_func, _ = ast_to_func(ast_root, func.__name__) + transformed_func, _ = ast_to_func(ast_root, func) return transformed_func def test_ast2func(self): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 2502064b2682bdf57b4e2871f38ca2b05b86e846..57194c29fa2be8aee650078ca3d3c1b216848a43 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -147,5 +147,43 @@ class TestDygraphIfElseNet(unittest.TestCase): self.assertTrue((self._run_dygraph() == self._run_static()).all()) +def call_external_func(x, label=None): + if fluid.layers.mean(x).numpy()[0] > 5: + x_v = x - 1 + else: + x_v = add_fn(x) + + if label is not None: + loss = loss_fn(x_v, label) + return loss + return x_v + + +class TestAst2FuncWithExternalFunc(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = call_external_func + + +class NetWithExternalFunc(fluid.dygraph.Layer): + @dygraph_to_static_graph + def forward(self, x, label=None): + if fluid.layers.mean(x).numpy()[0] > 5: + x_v = x - 1 + else: + x_v = add_fn(x) + + if label is not None: + loss = loss_fn(x_v, label) + return loss + return x_v + + +class TestNetWithExternalFunc(TestDygraphIfElseNet): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.Net = NetWithExternalFunc + + if __name__ == '__main__': unittest.main()