diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index ba011f52a4d421a06b97b7f32d3aec5958ba8580..b7d25e2a14b49166e3ea8ad5e6d63f75ef2517a7 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -226,7 +226,7 @@ def convert_call(func): # So descriptor mechanism is used to bound `self` instance on function to # keep it as bound method. setattr(func, 'forward', forward_func.__get__(func)) - except Exception: + except (IOError, OSError, TypeError): # NOTE: func.forward may have been decorated. func_self = None if func_self else func_self converted_call = func @@ -235,7 +235,7 @@ def convert_call(func): call_func = func.__class__.__call__ converted_call = convert_to_static(call_func) func_self = func - except Exception: + except (IOError, OSError, TypeError): # NOTE: # If `func` is a class which is being initialized, for example `convert_call(Foo)()`, # it doesn't need to be transformed diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py index 16ed8670da4bc8362c783694974cee17a47ed477..28c5d220213f1a47f76367cfc27c3e73a96562dd 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py @@ -17,6 +17,39 @@ import paddle import unittest +class CallNotExist(paddle.nn.Layer): + def __call__(self): + # call a non-exist API to trigger exception + return paddle.nn.not_exist_api + + +class ForwardNotExist(paddle.nn.Layer): + def forward(self): + return 0 + + +net = ForwardNotExist() +setattr(net, "forward", "A string so that convert forward will fail") + + +class TestConvertCall(unittest.TestCase): + def test_class_exception(self): + @paddle.jit.to_static + def call_not_exist(): + net = CallNotExist() + return net() + + with self.assertRaises(AttributeError): + call_not_exist() + + @paddle.jit.to_static + def forward_not_exist(): + return net() + + with self.assertRaises(TypeError): + forward_not_exist() + + class TestConvertShapeCompare(unittest.TestCase): def test_non_variable(self): self.assertEqual(