未验证 提交 31fc3ab7 编写于 作者: A Aurelius84 提交者: GitHub

Support to use external function (#23057)

* Support to use external function test=develop

* refine the parms of ast_to_func test=develop
上级 3f371db8
...@@ -245,6 +245,5 @@ def convert_to_static(dyfunc): ...@@ -245,6 +245,5 @@ def convert_to_static(dyfunc):
root_wrapper = dygraph_to_static.get_static_ast(root) root_wrapper = dygraph_to_static.get_static_ast(root)
# Get static_func from AST # Get static_func from AST
func_name = dygraph_to_static.get_module_name() static_func, file_name = ast_to_func(root_wrapper.node, dyfunc)
static_func, file_name = ast_to_func(root_wrapper.node, func_name)
return static_func, dygraph_to_static return static_func, dygraph_to_static
...@@ -299,6 +299,42 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids): ...@@ -299,6 +299,42 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids):
return func_def_node return func_def_node
class ImportVisitor(gast.NodeVisitor):
"""
Visitor to parse all `import` statement.
"""
def __init__(self, file_name):
self.root = self.file_to_ast(file_name)
self.import_statements = []
def transform(self):
if self.root is not None:
self.visit(self.root)
self.after_visit()
return self.import_statements
def visit_Import(self, node):
self.import_statements.append(ast_to_source_code(node))
return node
def visit_ImportFrom(self, node):
self.import_statements.append(ast_to_source_code(node))
return node
def after_visit(self):
essential_statements = ["import paddle.fluid as fluid\n"]
new_stmts = set(essential_statements) - set(self.import_statements)
self.import_statements.extend(list(new_stmts))
def file_to_ast(self, file_name):
root = None
if file_name is not None:
with open(file_name) as f:
root = gast.parse(f.read())
return root
def index_in_list(array_list, item): def index_in_list(array_list, item):
try: try:
return array_list.index(item) return array_list.index(item)
...@@ -307,7 +343,7 @@ def index_in_list(array_list, item): ...@@ -307,7 +343,7 @@ def index_in_list(array_list, item):
return -1 return -1
def ast_to_func(ast_root, func_name, delete_on_exit=True): def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
""" """
Transform modified AST of decorated function into python callable object. Transform modified AST of decorated function into python callable object.
""" """
...@@ -318,13 +354,13 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True): ...@@ -318,13 +354,13 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
else: else:
f = tempfile.NamedTemporaryFile( f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8') mode='w', suffix='.py', delete=False, encoding='utf-8')
# `sys.modules` is used to cache all modules and packages that avoids
# TODO(Aurelius84): more elegant way to transform ast into callable object # to import same modules twice by the import mechanism in python.
import_str = "import paddle\n" \ # We insert the import statements defined in source file into the tmpfile
"import paddle.fluid as fluid\n" \ # to make it easier to import external functions correctly.
"import paddle.fluid.layers as layers\n" \ source_file = inspect.getfile(dyfunc)
"import numpy as np\n" \ import_statements = ImportVisitor(source_file).transform()
"import numpy\n" import_str = "".join(import_statements)
with f: with f:
module_name = os.path.basename(f.name[:-3]) module_name = os.path.basename(f.name[:-3])
f.write(import_str) f.write(import_str)
...@@ -333,6 +369,7 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True): ...@@ -333,6 +369,7 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
if delete_on_exit: if delete_on_exit:
atexit.register(lambda: os.remove(f.name)) atexit.register(lambda: os.remove(f.name))
module = imp.load_source(module_name, f.name) module = imp.load_source(module_name, f.name)
func_name = dyfunc.__name__
if not hasattr(module, func_name): if not hasattr(module, func_name):
raise ValueError( raise ValueError(
'Function: %s doesn\'t exist in the Module transformed from AST.' % 'Function: %s doesn\'t exist in the Module transformed from AST.' %
......
...@@ -18,6 +18,16 @@ import paddle.fluid as fluid ...@@ -18,6 +18,16 @@ import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_graph from paddle.fluid.dygraph.jit import dygraph_to_static_graph
def add_fn(x):
x = x + 1
return x
def loss_fn(x, lable):
loss = fluid.layers.cross_entropy(x, lable)
return loss
def dyfunc_with_if_else(x_v, label=None): def dyfunc_with_if_else(x_v, label=None):
if fluid.layers.mean(x_v).numpy()[0] > 5: if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1 x_v = x_v - 1
......
...@@ -34,7 +34,7 @@ class TestAST2Func(unittest.TestCase): ...@@ -34,7 +34,7 @@ class TestAST2Func(unittest.TestCase):
source = inspect.getsource(func) source = inspect.getsource(func)
source = textwrap.dedent(source) source = textwrap.dedent(source)
ast_root = gast.parse(source) ast_root = gast.parse(source)
transformed_func, _ = ast_to_func(ast_root, func.__name__) transformed_func, _ = ast_to_func(ast_root, func)
return transformed_func return transformed_func
def test_ast2func(self): def test_ast2func(self):
......
...@@ -147,5 +147,43 @@ class TestDygraphIfElseNet(unittest.TestCase): ...@@ -147,5 +147,43 @@ class TestDygraphIfElseNet(unittest.TestCase):
self.assertTrue((self._run_dygraph() == self._run_static()).all()) self.assertTrue((self._run_dygraph() == self._run_static()).all())
def call_external_func(x, label=None):
if fluid.layers.mean(x).numpy()[0] > 5:
x_v = x - 1
else:
x_v = add_fn(x)
if label is not None:
loss = loss_fn(x_v, label)
return loss
return x_v
class TestAst2FuncWithExternalFunc(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = call_external_func
class NetWithExternalFunc(fluid.dygraph.Layer):
@dygraph_to_static_graph
def forward(self, x, label=None):
if fluid.layers.mean(x).numpy()[0] > 5:
x_v = x - 1
else:
x_v = add_fn(x)
if label is not None:
loss = loss_fn(x_v, label)
return loss
return x_v
class TestNetWithExternalFunc(TestDygraphIfElseNet):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.Net = NetWithExternalFunc
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册