未验证 提交 318c401e 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2St]Fix is_paddle_func not consider nn.Squential (#51763)

上级 4f32aae5
...@@ -25,6 +25,7 @@ from paddle.distributed.fleet import auto ...@@ -25,6 +25,7 @@ from paddle.distributed.fleet import auto
from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import _non_static_mode
from paddle.io import Dataset from paddle.io import Dataset
from paddle.jit.dy2static.utils import is_paddle_func from paddle.jit.dy2static.utils import is_paddle_func
from paddle.nn import Sequential
from paddle.static import InputSpec from paddle.static import InputSpec
batch_size = 4 batch_size = 4
...@@ -199,6 +200,9 @@ class TestIgnoreProxyLayer(unittest.TestCase): ...@@ -199,6 +200,9 @@ class TestIgnoreProxyLayer(unittest.TestCase):
self.assertFalse(is_paddle_func(proxy_layer._train)) self.assertFalse(is_paddle_func(proxy_layer._train))
self.assertFalse(is_paddle_func(proxy_layer._eval)) self.assertFalse(is_paddle_func(proxy_layer._eval))
self.assertFalse(is_paddle_func(proxy_layer._predict)) self.assertFalse(is_paddle_func(proxy_layer._predict))
# test for nn.Sequential
net = Sequential(('mlp', mlp))
self.assertFalse(is_paddle_func(net))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -262,7 +262,7 @@ def make_hashable(x, error_msg=None): ...@@ -262,7 +262,7 @@ def make_hashable(x, error_msg=None):
# NOTE(Aurelius84): Consider the following paddle inner API as common case to # NOTE(Aurelius84): Consider the following paddle inner API as common case to
# apply @to_static code transformation as usual. Because they contains # apply @to_static code transformation as usual. Because they contains
# user-defined layer, like paddle.distributed.auto_parallel.helper.ProxyLayer. # user-defined layer, like paddle.distributed.auto_parallel.helper.ProxyLayer.
AS_NOT_INNER_FUNC_LIST = set() AS_NOT_INNER_FUNC_LIST = {"paddle.nn.layer.container.Sequential"}
def as_not_paddle_func(path): def as_not_paddle_func(path):
...@@ -293,6 +293,8 @@ def is_paddle_func(func, ignore_white_list=True): ...@@ -293,6 +293,8 @@ def is_paddle_func(func, ignore_white_list=True):
if inspect.ismethod(func): if inspect.ismethod(func):
func_name = func.__self__.__class__.__name__ func_name = func.__self__.__class__.__name__
func = func.__func__ func = func.__func__
elif hasattr(func, '__class__'): # for nn.Sequential
func_name = func.__class__.__name__
m = inspect.getmodule(func) m = inspect.getmodule(func)
flag = m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX) flag = m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册