From 3b58a68f693bea0f99887d079116c11b8cb268ea Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Thu, 20 Jul 2023 14:17:27 +0800 Subject: [PATCH] [Dy2St] fix `func_self` maybe a callable empty list (#55554) --- python/paddle/jit/dy2static/convert_call_func.py | 2 +- test/dygraph_to_static/test_convert_operators.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/paddle/jit/dy2static/convert_call_func.py b/python/paddle/jit/dy2static/convert_call_func.py index 2c9083f2a3a..7f086680b48 100644 --- a/python/paddle/jit/dy2static/convert_call_func.py +++ b/python/paddle/jit/dy2static/convert_call_func.py @@ -339,6 +339,6 @@ def convert_call(func): ) return func - if func_self: + if func_self is not None: converted_call = functools.partial(converted_call, func_self) return converted_call diff --git a/test/dygraph_to_static/test_convert_operators.py b/test/dygraph_to_static/test_convert_operators.py index 8e2ea05004e..c426ebcd32c 100644 --- a/test/dygraph_to_static/test_convert_operators.py +++ b/test/dygraph_to_static/test_convert_operators.py @@ -25,6 +25,11 @@ class CallNotExist(paddle.nn.Layer): return paddle.nn.not_exist_api +class CallableList(list): + def __call__(self, x): + return x + + class ForwardNotExist(paddle.nn.Layer): def forward(self): return 0 @@ -51,6 +56,14 @@ class TestConvertCall(unittest.TestCase): with self.assertRaises(AttributeError): forward_not_exist() + def test_callable_list(self): + @paddle.jit.to_static + def callable_list(x, y): + callable_list = CallableList() + return callable_list(x) + y + + self.assertEqual(callable_list(1, 2), 3) + class TestConvertShapeCompare(unittest.TestCase): def test_non_variable(self): -- GitLab