未验证 提交 d1f9a26c 编写于 作者: A Aurelius84 提交者: GitHub

Refine mechanism of calling outer function in dy2static (#23688)

* Refine mechanism of calling outer function test=develop

* fix typo test=develop
上级 4773e3f5
......@@ -308,42 +308,6 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids):
return func_def_node
class ImportVisitor(gast.NodeVisitor):
"""
Visitor to parse all `import` statement.
"""
def __init__(self, file_name):
self.root = self.file_to_ast(file_name)
self.import_statements = []
def transform(self):
if self.root is not None:
self.visit(self.root)
self.after_visit()
return self.import_statements
def visit_Import(self, node):
self.import_statements.append(ast_to_source_code(node))
return node
def visit_ImportFrom(self, node):
self.import_statements.append(ast_to_source_code(node))
return node
def after_visit(self):
essential_statements = ["import paddle.fluid as fluid\n"]
new_stmts = set(essential_statements) - set(self.import_statements)
self.import_statements.extend(list(new_stmts))
def file_to_ast(self, file_name):
root = None
if file_name is not None:
with open(file_name) as f:
root = gast.parse(f.read())
return root
def index_in_list(array_list, item):
try:
return array_list.index(item)
......@@ -392,6 +356,8 @@ class RenameTransformer(gast.NodeTransformer):
def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
"""
Transform modified AST of decorated function into python callable object.
TODO: If only decorate one of inner function instead of decorating the main
function, the other inner functions are invisible for the decorated function.
"""
source = ast_to_source_code(ast_root)
if six.PY2:
......@@ -400,16 +366,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
else:
f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8')
# `sys.modules` is used to cache all modules and packages that avoids
# to import same modules twice by the import mechanism in python.
# We insert the import statements defined in source file into the tmpfile
# to make it easier to import external functions correctly.
source_file = inspect.getfile(dyfunc)
import_statements = ImportVisitor(source_file).transform()
import_str = "".join(import_statements)
with f:
module_name = os.path.basename(f.name[:-3])
f.write(import_str)
f.write(source)
if delete_on_exit:
......@@ -420,8 +378,25 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
raise ValueError(
'Function: %s doesn\'t exist in the Module transformed from AST.' %
func_name)
callable_func = getattr(module, func_name)
# After transform dygraph function into callable_func saved in tmp file,
# it lost the global variables from imported statements or defined in source file.
# Recovers the necessary variables by `__globals__`.
recover_globals_attribute(dyfunc, callable_func)
return callable_func, f.name
def recover_globals_attribute(src_obj, dst_obj):
attr_name = '__globals__'
src_globals = getattr(src_obj, attr_name, {})
dst_globals = getattr(dst_obj, attr_name, {})
return getattr(module, func_name), f.name
for k, v in src_globals.items():
# ignore builtin attribute.
if not (k.startswith('__') and k.endswith('__')):
dst_globals[k] = v
def ast_to_source_code(ast_node):
......
......@@ -99,14 +99,16 @@ class MainNetWithDict(fluid.dygraph.Layer):
out = input
for i in range(max_len):
out = self.sub_net(out, cache)
cache = self.update_cache(cache)
cache = update_cache(cache)
return out
def update_cache(self, cache):
for k, val in six.iteritems(cache):
cache[k] = fluid.layers.softmax(val)
return cache
# Test to call function defined outside of class.
def update_cache(cache):
for k, val in six.iteritems(cache):
cache[k] = fluid.layers.softmax(val)
return cache
class TestNetWithDict(unittest.TestCase):
......
......@@ -209,12 +209,18 @@ class TestDygraphIfElseNet(unittest.TestCase):
self.assertTrue((self._run_dygraph() == self._run_static()).all())
# Test to call function ahead caller.
def relu(x):
return fluid.layers.relu(x)
def call_external_func(x, label=None):
if fluid.layers.mean(x).numpy()[0] > 5:
if fluid.layers.mean(x) < 0:
x_v = x - 1
else:
x_v = add_fn(x)
x_v = relu(x_v)
if label is not None:
loss = loss_fn(x_v, label)
return loss
......@@ -230,17 +236,23 @@ class TestAst2FuncWithExternalFunc(TestDygraphIfElse):
class NetWithExternalFunc(fluid.dygraph.Layer):
@dygraph_to_static_func
def forward(self, x, label=None):
if fluid.layers.mean(x).numpy()[0] > 5:
if fluid.layers.mean(x) < 0:
x_v = x - 1
else:
x_v = add_fn(x)
x_v = softmax(x_v)
if label is not None:
loss = loss_fn(x_v, label)
return loss
return x_v
# Test to call function behind caller.
def softmax(x):
return fluid.layers.softmax(x)
class TestNetWithExternalFunc(TestDygraphIfElseNet):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册