diff --git a/python/paddle/jit/dy2static/convert_call_func.py b/python/paddle/jit/dy2static/convert_call_func.py index 2c9083f2a3a76f29df27fce4e9d804aeb3646ae2..7f086680b48f6ce3e7294be4851cb2372afc0cdf 100644 --- a/python/paddle/jit/dy2static/convert_call_func.py +++ b/python/paddle/jit/dy2static/convert_call_func.py @@ -339,6 +339,6 @@ def convert_call(func): ) return func - if func_self: + if func_self is not None: converted_call = functools.partial(converted_call, func_self) return converted_call diff --git a/test/dygraph_to_static/test_convert_operators.py b/test/dygraph_to_static/test_convert_operators.py index 8e2ea05004e8021ef26a7e67e1ea523a2c703547..c426ebcd32c287467d6170ba459c2843918430c2 100644 --- a/test/dygraph_to_static/test_convert_operators.py +++ b/test/dygraph_to_static/test_convert_operators.py @@ -25,6 +25,11 @@ class CallNotExist(paddle.nn.Layer): return paddle.nn.not_exist_api +class CallableList(list): + def __call__(self, x): + return x + + class ForwardNotExist(paddle.nn.Layer): def forward(self): return 0 @@ -51,6 +56,14 @@ class TestConvertCall(unittest.TestCase): with self.assertRaises(AttributeError): forward_not_exist() + def test_callable_list(self): + @paddle.jit.to_static + def callable_list(x, y): + callable_list = CallableList() + return callable_list(x) + y + + self.assertEqual(callable_list(1, 2), 3) + class TestConvertShapeCompare(unittest.TestCase): def test_non_variable(self):