未验证 提交 eb0ca3b7 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] Fix function lookup bug in convert_call (#24567) (#24571)

* fix convert call globals_funcs test=develop

* add import statement test=develop
上级 ca41e552
...@@ -102,7 +102,10 @@ def convert_call(func): ...@@ -102,7 +102,10 @@ def convert_call(func):
if func.__name__ == '<lambda>': if func.__name__ == '<lambda>':
return func return func
try: try:
if func in func.__globals__.values(): global_funcs = set([
fn for fn in func.__globals__.values() if inspect.isfunction(fn)
])
if func in global_funcs:
converted_call = to_static_func(func) converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None) func_self = getattr(func, '__self__', None)
except AttributeError: except AttributeError:
......
...@@ -368,9 +368,10 @@ class ProgramTranslator(object): ...@@ -368,9 +368,10 @@ class ProgramTranslator(object):
prog_trans = fluid.dygraph.ProgramTranslator() prog_trans = fluid.dygraph.ProgramTranslator()
x = np.ones([1, 2]) with fluid.dygraph.guard():
x_v = prog_trans.get_output(func, x) x = np.ones([1, 2])
print(x_v.numpy()) # [[0. 0.]] x_v = prog_trans.get_output(func, x)
print(x_v.numpy()) # [[0. 0.]]
""" """
assert callable( assert callable(
...@@ -472,7 +473,7 @@ class ProgramTranslator(object): ...@@ -472,7 +473,7 @@ class ProgramTranslator(object):
x = np.ones([1, 2]) x = np.ones([1, 2])
main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x) main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x)
print([i.name for i in inputs]) print([i.name for i in inputs])
# ['x_0'] the feed input variable name representing x # ['feed_0'] the feed input variable name representing x
print([o.name for o in outputs]) print([o.name for o in outputs])
# ['_generated_var_4'] the fetch output variable name representing x_v # ['_generated_var_4'] the fetch output variable name representing x_v
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册