未验证 提交 d1f9a26c 编写于 作者: A Aurelius84 提交者: GitHub

Refine mechanism of calling outer function in dy2static (#23688)

* Refine mechanism of calling outer function test=develop

* fix typo test=develop
上级 4773e3f5
...@@ -308,42 +308,6 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids): ...@@ -308,42 +308,6 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids):
return func_def_node return func_def_node
class ImportVisitor(gast.NodeVisitor):
"""
Visitor to parse all `import` statement.
"""
def __init__(self, file_name):
self.root = self.file_to_ast(file_name)
self.import_statements = []
def transform(self):
if self.root is not None:
self.visit(self.root)
self.after_visit()
return self.import_statements
def visit_Import(self, node):
self.import_statements.append(ast_to_source_code(node))
return node
def visit_ImportFrom(self, node):
self.import_statements.append(ast_to_source_code(node))
return node
def after_visit(self):
essential_statements = ["import paddle.fluid as fluid\n"]
new_stmts = set(essential_statements) - set(self.import_statements)
self.import_statements.extend(list(new_stmts))
def file_to_ast(self, file_name):
root = None
if file_name is not None:
with open(file_name) as f:
root = gast.parse(f.read())
return root
def index_in_list(array_list, item): def index_in_list(array_list, item):
try: try:
return array_list.index(item) return array_list.index(item)
...@@ -392,6 +356,8 @@ class RenameTransformer(gast.NodeTransformer): ...@@ -392,6 +356,8 @@ class RenameTransformer(gast.NodeTransformer):
def ast_to_func(ast_root, dyfunc, delete_on_exit=True): def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
""" """
Transform modified AST of decorated function into python callable object. Transform modified AST of decorated function into python callable object.
TODO: If only decorate one of inner function instead of decorating the main
function, the other inner functions are invisible for the decorated function.
""" """
source = ast_to_source_code(ast_root) source = ast_to_source_code(ast_root)
if six.PY2: if six.PY2:
...@@ -400,16 +366,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -400,16 +366,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
else: else:
f = tempfile.NamedTemporaryFile( f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8') mode='w', suffix='.py', delete=False, encoding='utf-8')
# `sys.modules` is used to cache all modules and packages that avoids
# to import same modules twice by the import mechanism in python.
# We insert the import statements defined in source file into the tmpfile
# to make it easier to import external functions correctly.
source_file = inspect.getfile(dyfunc)
import_statements = ImportVisitor(source_file).transform()
import_str = "".join(import_statements)
with f: with f:
module_name = os.path.basename(f.name[:-3]) module_name = os.path.basename(f.name[:-3])
f.write(import_str)
f.write(source) f.write(source)
if delete_on_exit: if delete_on_exit:
...@@ -420,8 +378,25 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -420,8 +378,25 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
raise ValueError( raise ValueError(
'Function: %s doesn\'t exist in the Module transformed from AST.' % 'Function: %s doesn\'t exist in the Module transformed from AST.' %
func_name) func_name)
callable_func = getattr(module, func_name)
# After transform dygraph function into callable_func saved in tmp file,
# it lost the global variables from imported statements or defined in source file.
# Recovers the necessary variables by `__globals__`.
recover_globals_attribute(dyfunc, callable_func)
return callable_func, f.name
def recover_globals_attribute(src_obj, dst_obj):
attr_name = '__globals__'
src_globals = getattr(src_obj, attr_name, {})
dst_globals = getattr(dst_obj, attr_name, {})
return getattr(module, func_name), f.name for k, v in src_globals.items():
# ignore builtin attribute.
if not (k.startswith('__') and k.endswith('__')):
dst_globals[k] = v
def ast_to_source_code(ast_node): def ast_to_source_code(ast_node):
......
...@@ -99,10 +99,12 @@ class MainNetWithDict(fluid.dygraph.Layer): ...@@ -99,10 +99,12 @@ class MainNetWithDict(fluid.dygraph.Layer):
out = input out = input
for i in range(max_len): for i in range(max_len):
out = self.sub_net(out, cache) out = self.sub_net(out, cache)
cache = self.update_cache(cache) cache = update_cache(cache)
return out return out
def update_cache(self, cache):
# Test to call function defined outside of class.
def update_cache(cache):
for k, val in six.iteritems(cache): for k, val in six.iteritems(cache):
cache[k] = fluid.layers.softmax(val) cache[k] = fluid.layers.softmax(val)
......
...@@ -209,12 +209,18 @@ class TestDygraphIfElseNet(unittest.TestCase): ...@@ -209,12 +209,18 @@ class TestDygraphIfElseNet(unittest.TestCase):
self.assertTrue((self._run_dygraph() == self._run_static()).all()) self.assertTrue((self._run_dygraph() == self._run_static()).all())
# Test to call function ahead caller.
def relu(x):
return fluid.layers.relu(x)
def call_external_func(x, label=None): def call_external_func(x, label=None):
if fluid.layers.mean(x).numpy()[0] > 5: if fluid.layers.mean(x) < 0:
x_v = x - 1 x_v = x - 1
else: else:
x_v = add_fn(x) x_v = add_fn(x)
x_v = relu(x_v)
if label is not None: if label is not None:
loss = loss_fn(x_v, label) loss = loss_fn(x_v, label)
return loss return loss
...@@ -230,17 +236,23 @@ class TestAst2FuncWithExternalFunc(TestDygraphIfElse): ...@@ -230,17 +236,23 @@ class TestAst2FuncWithExternalFunc(TestDygraphIfElse):
class NetWithExternalFunc(fluid.dygraph.Layer): class NetWithExternalFunc(fluid.dygraph.Layer):
@dygraph_to_static_func @dygraph_to_static_func
def forward(self, x, label=None): def forward(self, x, label=None):
if fluid.layers.mean(x).numpy()[0] > 5: if fluid.layers.mean(x) < 0:
x_v = x - 1 x_v = x - 1
else: else:
x_v = add_fn(x) x_v = add_fn(x)
x_v = softmax(x_v)
if label is not None: if label is not None:
loss = loss_fn(x_v, label) loss = loss_fn(x_v, label)
return loss return loss
return x_v return x_v
# Test to call function behind caller.
def softmax(x):
return fluid.layers.softmax(x)
class TestNetWithExternalFunc(TestDygraphIfElseNet): class TestNetWithExternalFunc(TestDygraphIfElseNet):
def setUp(self): def setUp(self):
self.x = np.random.random([10, 16]).astype('float32') self.x = np.random.random([10, 16]).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册