From 981244cfab1c767e541bc01c14b01f96cd16831f Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Fri, 4 Dec 2020 15:59:17 +0800 Subject: [PATCH] [Dy2stat] Reduce Exception Type for Better Error Message (#29268) (#29363) Reduce exception type so that if covert_to_static failed, it reports right error message. --- .../dygraph_to_static/convert_call_func.py | 4 +-- .../test_convert_operators.py | 33 +++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index ba011f52a4d..b7d25e2a14b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -226,7 +226,7 @@ def convert_call(func): # So descriptor mechanism is used to bound `self` instance on function to # keep it as bound method. setattr(func, 'forward', forward_func.__get__(func)) - except Exception: + except (IOError, OSError, TypeError): # NOTE: func.forward may have been decorated. func_self = None if func_self else func_self converted_call = func @@ -235,7 +235,7 @@ def convert_call(func): call_func = func.__class__.__call__ converted_call = convert_to_static(call_func) func_self = func - except Exception: + except (IOError, OSError, TypeError): # NOTE: # If `func` is a class which is being initialized, for example `convert_call(Foo)()`, # it doesn't need to be transformed diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py index 16ed8670da4..28c5d220213 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py @@ -17,6 +17,39 @@ import paddle import unittest +class CallNotExist(paddle.nn.Layer): + def __call__(self): + # call a non-exist API to trigger exception + return paddle.nn.not_exist_api + + +class ForwardNotExist(paddle.nn.Layer): + def forward(self): + return 0 + + +net = ForwardNotExist() +setattr(net, "forward", "A string so that convert forward will fail") + + +class TestConvertCall(unittest.TestCase): + def test_class_exception(self): + @paddle.jit.to_static + def call_not_exist(): + net = CallNotExist() + return net() + + with self.assertRaises(AttributeError): + call_not_exist() + + @paddle.jit.to_static + def forward_not_exist(): + return net() + + with self.assertRaises(TypeError): + forward_not_exist() + + class TestConvertShapeCompare(unittest.TestCase): def test_non_variable(self): self.assertEqual( -- GitLab