未验证 提交 1507f77a 编写于 作者: L liym27 提交者: GitHub

Fix bug in convert_call because difference exists between python3 and python2....

Fix bug in convert_call because difference exists between python3 and python2. test=develop (#23966)
上级 455ed267
...@@ -30,6 +30,7 @@ import six ...@@ -30,6 +30,7 @@ import six
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
to_static_func = program_translator.get_func to_static_func = program_translator.get_func
...@@ -102,8 +103,17 @@ def convert_call(func): ...@@ -102,8 +103,17 @@ def convert_call(func):
return func return func
try: try:
if func in func.__globals__.values(): if func in func.__globals__.values():
converted_call = to_static_func(func) if six.PY3:
func_self = getattr(func, '__self__', None) source_code = inspect.getsource(func)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
except AttributeError: except AttributeError:
# NOTE: # NOTE:
# If func is not in __globals__, it does not need to be transformed # If func is not in __globals__, it does not need to be transformed
...@@ -116,8 +126,17 @@ def convert_call(func): ...@@ -116,8 +126,17 @@ def convert_call(func):
converted_call = None converted_call = None
elif inspect.ismethod(func): elif inspect.ismethod(func):
try: try:
func_self = getattr(func, '__self__', None) if six.PY3:
converted_call = to_static_func(func) source_code = inspect.getsource(func)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
except (IOError, OSError): except (IOError, OSError):
# NOTE: func may have beed decorated. # NOTE: func may have beed decorated.
converted_call = None converted_call = None
...@@ -125,9 +144,20 @@ def convert_call(func): ...@@ -125,9 +144,20 @@ def convert_call(func):
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'): elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer): if hasattr(func, 'forward') and isinstance(func, Layer):
try: try:
forward_func = to_static_func(func.forward) if six.PY3:
setattr(func, 'forward', forward_func) source_code = inspect.getsource(func.forward)
func_self = func if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
forward_func = to_static_func(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
else:
forward_func = to_static_func(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
except Exception: except Exception:
# NOTE: func.forward may have beed decorated. # NOTE: func.forward may have beed decorated.
func_self = None if func_self else func_self func_self = None if func_self else func_self
...@@ -148,5 +178,4 @@ def convert_call(func): ...@@ -148,5 +178,4 @@ def convert_call(func):
if func_self: if func_self:
converted_call = functools.partial(converted_call, func_self) converted_call = functools.partial(converted_call, func_self)
return converted_call return converted_call
...@@ -25,6 +25,8 @@ SEED = 2020 ...@@ -25,6 +25,8 @@ SEED = 2020
np.random.seed(SEED) np.random.seed(SEED)
# Use a decorator to test exception
@dygraph_to_static_func
def dyfunc_with_if(x_v): def dyfunc_with_if(x_v):
if fluid.layers.mean(x_v).numpy()[0] > 5: if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1 x_v = x_v - 1
...@@ -91,6 +93,7 @@ class MyConvLayer(fluid.dygraph.Layer): ...@@ -91,6 +93,7 @@ class MyConvLayer(fluid.dygraph.Layer):
bias_attr=fluid.ParamAttr( bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5))) initializer=fluid.initializer.Constant(value=0.5)))
@dygraph_to_static_func
def forward(self, inputs): def forward(self, inputs):
y = dyfunc_with_if(inputs) y = dyfunc_with_if(inputs)
y = lambda_fun(y) y = lambda_fun(y)
...@@ -99,6 +102,7 @@ class MyConvLayer(fluid.dygraph.Layer): ...@@ -99,6 +102,7 @@ class MyConvLayer(fluid.dygraph.Layer):
@dygraph_to_static_func @dygraph_to_static_func
def dymethod(self, x_v): def dymethod(self, x_v):
x_v = fluid.layers.assign(x_v)
return x_v return x_v
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册