From db0c1ea851b3bac249863375cca75681adc62c28 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 15 May 2020 09:10:53 +0800 Subject: [PATCH] [Dy2stat] Fix function lookup bug in convert_call (#24567) * fix convert call globals_funcs test=develop * add import statement test=develop --- .../dygraph/dygraph_to_static/convert_call_func.py | 5 ++++- .../dygraph/dygraph_to_static/program_translator.py | 10 ++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index 5aa0ffb3e4..1532d5be37 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -102,7 +102,10 @@ def convert_call(func): if func.__name__ == '': return func try: - if func in func.__globals__.values(): + global_funcs = set([ + fn for fn in func.__globals__.values() if inspect.isfunction(fn) + ]) + if func in global_funcs: converted_call = to_static_func(func) func_self = getattr(func, '__self__', None) except AttributeError: diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index f480df8a6f..db7d59096f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -368,9 +368,10 @@ class ProgramTranslator(object): prog_trans = fluid.dygraph.ProgramTranslator() - x = np.ones([1, 2]) - x_v = prog_trans.get_output(func, x) - print(x_v.numpy()) # [[0. 0.]] + with fluid.dygraph.guard(): + x = np.ones([1, 2]) + x_v = prog_trans.get_output(func, x) + print(x_v.numpy()) # [[0. 0.]] """ assert callable( @@ -472,7 +473,7 @@ class ProgramTranslator(object): x = np.ones([1, 2]) main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x) print([i.name for i in inputs]) - # ['x_0'] the feed input variable name representing x + # ['feed_0'] the feed input variable name representing x print([o.name for o in outputs]) # ['_generated_var_4'] the fetch output variable name representing x_v @@ -573,6 +574,7 @@ class ProgramTranslator(object): import numpy as np import paddle.fluid as fluid from paddle.fluid.dygraph import Linear + from paddle.fluid.dygraph import declarative from paddle.fluid.dygraph import ProgramTranslator class SimpleNet(fluid.dygraph.Layer): -- GitLab