未验证 提交 8f762790 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Reduce Exception Type for Better Error Message (#29268)

Reduce exception type so that if covert_to_static failed, it reports right error message.
上级 61a8f287
...@@ -226,7 +226,7 @@ def convert_call(func): ...@@ -226,7 +226,7 @@ def convert_call(func):
# So descriptor mechanism is used to bound `self` instance on function to # So descriptor mechanism is used to bound `self` instance on function to
# keep it as bound method. # keep it as bound method.
setattr(func, 'forward', forward_func.__get__(func)) setattr(func, 'forward', forward_func.__get__(func))
except Exception: except (IOError, OSError, TypeError):
# NOTE: func.forward may have been decorated. # NOTE: func.forward may have been decorated.
func_self = None if func_self else func_self func_self = None if func_self else func_self
converted_call = func converted_call = func
...@@ -235,7 +235,7 @@ def convert_call(func): ...@@ -235,7 +235,7 @@ def convert_call(func):
call_func = func.__class__.__call__ call_func = func.__class__.__call__
converted_call = convert_to_static(call_func) converted_call = convert_to_static(call_func)
func_self = func func_self = func
except Exception: except (IOError, OSError, TypeError):
# NOTE: # NOTE:
# If `func` is a class which is being initialized, for example `convert_call(Foo)()`, # If `func` is a class which is being initialized, for example `convert_call(Foo)()`,
# it doesn't need to be transformed # it doesn't need to be transformed
......
...@@ -17,6 +17,39 @@ import paddle ...@@ -17,6 +17,39 @@ import paddle
import unittest import unittest
class CallNotExist(paddle.nn.Layer):
def __call__(self):
# call a non-exist API to trigger exception
return paddle.nn.not_exist_api
class ForwardNotExist(paddle.nn.Layer):
def forward(self):
return 0
net = ForwardNotExist()
setattr(net, "forward", "A string so that convert forward will fail")
class TestConvertCall(unittest.TestCase):
def test_class_exception(self):
@paddle.jit.to_static
def call_not_exist():
net = CallNotExist()
return net()
with self.assertRaises(AttributeError):
call_not_exist()
@paddle.jit.to_static
def forward_not_exist():
return net()
with self.assertRaises(TypeError):
forward_not_exist()
class TestConvertShapeCompare(unittest.TestCase): class TestConvertShapeCompare(unittest.TestCase):
def test_non_variable(self): def test_non_variable(self):
self.assertEqual( self.assertEqual(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册