未验证 提交 1507f77a 编写于 作者: L liym27 提交者: GitHub

Fix bug in convert_call because difference exists between python3 and python2....

Fix bug in convert_call because difference exists between python3 and python2. test=develop (#23966)
上级 455ed267
......@@ -30,6 +30,7 @@ import six
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.layers import Layer
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
program_translator = ProgramTranslator()
to_static_func = program_translator.get_func
......@@ -102,8 +103,17 @@ def convert_call(func):
return func
try:
if func in func.__globals__.values():
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
if six.PY3:
source_code = inspect.getsource(func)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
except AttributeError:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
......@@ -116,8 +126,17 @@ def convert_call(func):
converted_call = None
elif inspect.ismethod(func):
try:
func_self = getattr(func, '__self__', None)
converted_call = to_static_func(func)
if six.PY3:
source_code = inspect.getsource(func)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
except (IOError, OSError):
# NOTE: func may have beed decorated.
converted_call = None
......@@ -125,9 +144,20 @@ def convert_call(func):
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
forward_func = to_static_func(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
if six.PY3:
source_code = inspect.getsource(func.forward)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
forward_func = to_static_func(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
else:
forward_func = to_static_func(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
except Exception:
# NOTE: func.forward may have beed decorated.
func_self = None if func_self else func_self
......@@ -148,5 +178,4 @@ def convert_call(func):
if func_self:
converted_call = functools.partial(converted_call, func_self)
return converted_call
......@@ -25,6 +25,8 @@ SEED = 2020
np.random.seed(SEED)
# Use a decorator to test exception
@dygraph_to_static_func
def dyfunc_with_if(x_v):
if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1
......@@ -91,6 +93,7 @@ class MyConvLayer(fluid.dygraph.Layer):
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)))
@dygraph_to_static_func
def forward(self, inputs):
y = dyfunc_with_if(inputs)
y = lambda_fun(y)
......@@ -99,6 +102,7 @@ class MyConvLayer(fluid.dygraph.Layer):
@dygraph_to_static_func
def dymethod(self, x_v):
x_v = fluid.layers.assign(x_v)
return x_v
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册